結果

問題 No.1712 Read and Pile
ユーザー chineristAC
提出日時 2021-03-23 22:03:58
言語 PyPy3
(7.3.15)
結果
WA  
(最新)
AC  
(最初)
実行時間 -
コード長 2,236 bytes
コンパイル時間 260 ms
コンパイル使用メモリ 82,296 KB
実行使用メモリ 132,248 KB
最終ジャッジ日時 2024-09-17 16:53:31
合計ジャッジ時間 17,819 ms
ジャッジサーバーID
(参考情報)
judge6 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1 WA * 2
other AC * 6 WA * 34
権限があれば一括ダウンロードができます

ソースコード

diff #
プレゼンテーションモードにする

class BIT():
def __init__(self,n,mod=None):
self.BIT=[0]*(n+1)
self.num=n
self.mod = mod
def query(self,idx):
res_sum = 0
while idx > 0:
res_sum += self.BIT[idx]
if self.mod:
res_sum %= self.mod
idx -= idx&(-idx)
return res_sum
#Ai += x O(logN)
def update(self,idx,x):
while idx <= self.num:
self.BIT[idx] += x
if self.mod:
self.BIT[idx] %= self.mod
idx += idx&(-idx)
return
import sys,random,bisect
from collections import deque,defaultdict
from heapq import heapify,heappop,heappush
from itertools import permutations
from math import log,gcd
input = lambda :sys.stdin.readline().rstrip()
mi = lambda :map(int,input().split())
li = lambda :list(mi())
mod = 998244353
i2 = pow(2,mod-2,mod)
def solve(N,M,A):
p = ((N-2)*pow(N,mod-2,mod)) % mod
ip = pow(p,mod-2,mod)
pow_p = [1 for i in range(N+M+1)]
pow_ip = [1 for i in range(N+M+1)]
for i in range(1,N+M+1):
pow_p[i] = (pow_p[i-1] * p) % mod
pow_ip[i] = (pow_ip[i-1] * ip) % mod
A = [N-i for i in range(N)] + A
A = [-1] + A
zero = [A[i]==0 for i in range(N+M+1)]
for i in range(1,N+M+1):
zero[i] += zero[i-1]
lastappear = [N-i+1 for i in range(N+1)]
res = 0
bit_p = BIT(N+M,mod=998244353)
bit_cnt = BIT(N+M)
for i in range(1,N+1):
bit_cnt.update(i,1)
bit_p.update(i,pow_ip[zero[i]])
for i in range(N+1,N+M+1):
if A[i]==0:
res += (N-1)
res %= mod
else:
pre = lastappear[A[i]]
k = bit_cnt.query(pre-1)
res = (res + N-1) % mod
res += (bit_p.query(i)-bit_p.query(pre)) * pow_p[zero[i]] % mod
res %= mod
res -= k * pow_p[zero[i]-zero[pre]] % mod
res %= mod
bit_cnt.update(pre,-1)
bit_p.update(pre,-pow_ip[zero[pre]])
lastappear[A[i]] = i
bit_cnt.update(i,1)
bit_p.update(i,pow_ip[zero[i]])
res *= i2
res %= mod
res += M
res %= mod
return res
N,M = mi()
A = li()
print(solve(N,M,A))
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
0