結果
| 問題 |
No.1712 Read and Pile
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2021-03-10 19:23:17 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
(最新)
AC
(最初)
|
| 実行時間 | - |
| コード長 | 2,150 bytes |
| コンパイル時間 | 309 ms |
| コンパイル使用メモリ | 82,176 KB |
| 実行使用メモリ | 118,912 KB |
| 最終ジャッジ日時 | 2024-09-17 16:52:23 |
| 合計ジャッジ時間 | 18,584 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge6 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 WA * 2 |
| other | AC * 6 WA * 34 |
ソースコード
class BIT():
def __init__(self,n,mod=None):
self.BIT=[0]*(n+1)
self.num=n
self.mod = mod
def query(self,idx):
res_sum = 0
while idx > 0:
res_sum += self.BIT[idx]
if self.mod:
res_sum %= self.mod
idx -= idx&(-idx)
return res_sum
#Ai += x O(logN)
def update(self,idx,x):
while idx <= self.num:
self.BIT[idx] += x
if self.mod:
self.BIT[idx] %= self.mod
idx += idx&(-idx)
return
import sys,random,bisect
from collections import deque,defaultdict
from heapq import heapify,heappop,heappush
from itertools import permutations
from math import log,gcd
input = lambda :sys.stdin.readline().rstrip()
mi = lambda :map(int,input().split())
li = lambda :list(mi())
mod = 998244353
i2 = pow(2,mod-2,mod)
def solve(N,M,A):
p = ((N-2)*pow(N,mod-2,mod)) % mod
ip = pow(p,mod-2,mod)
A = [N-i for i in range(N)] + A
A = [-1] + A
zero = [A[i]==0 for i in range(N+M+1)]
for i in range(1,N+M+1):
zero[i] += zero[i-1]
Z = zero[-1]
lastappear = [N-i+1 for i in range(N+1)]
res = 0
bit_p = BIT(N+M,mod=998244353)
bit_cnt = BIT(N+M)
for i in range(1,N+1):
bit_cnt.update(i,1)
bit_p.update(i,pow(ip,zero[i],mod))
for i in range(N+1,N+M+1):
if A[i]==0:
res += (N-1)
res %= mod
else:
pre = lastappear[A[i]]
k = bit_cnt.query(pre-1)
res += (N-1-k) % mod
res %= mod
res += (bit_p.query(i)-bit_p.query(pre)) * pow(p,zero[i],mod) % mod
res %= mod
res += k % mod
res %= mod
res -= k * pow(p,zero[i]-zero[pre],mod) % mod
res %= mod
bit_cnt.update(pre,-1)
bit_p.update(pre,-pow(ip,zero[pre],mod))
lastappear[A[i]] = i
bit_cnt.update(i,1)
bit_p.update(i,pow(ip,zero[i],mod))
res *= i2
res %= mod
res += M
res %= mod
return res
N,M = mi()
A = li()
print(solve(N,M,A))