結果

問題 No.2898 Update Max
ユーザー convexineqconvexineq
提出日時 2024-09-23 22:48:21
言語 PyPy3
(7.3.15)
結果
RE  
実行時間 -
コード長 2,557 bytes
コンパイル時間 3,452 ms
コンパイル使用メモリ 82,004 KB
実行使用メモリ 114,316 KB
最終ジャッジ日時 2024-09-23 22:54:24
合計ジャッジ時間 12,870 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 43 ms
62,256 KB
testcase_01 AC 44 ms
61,956 KB
testcase_02 AC 44 ms
61,788 KB
testcase_03 AC 75 ms
74,032 KB
testcase_04 AC 76 ms
73,340 KB
testcase_05 AC 62 ms
73,108 KB
testcase_06 AC 61 ms
72,524 KB
testcase_07 AC 89 ms
72,232 KB
testcase_08 AC 171 ms
107,940 KB
testcase_09 AC 166 ms
107,864 KB
testcase_10 AC 169 ms
107,488 KB
testcase_11 AC 201 ms
108,004 KB
testcase_12 AC 168 ms
107,748 KB
testcase_13 AC 174 ms
108,460 KB
testcase_14 AC 175 ms
108,692 KB
testcase_15 AC 202 ms
108,732 KB
testcase_16 AC 174 ms
108,660 KB
testcase_17 AC 178 ms
108,448 KB
testcase_18 RE -
testcase_19 RE -
testcase_20 RE -
testcase_21 RE -
testcase_22 RE -
testcase_23 RE -
testcase_24 RE -
testcase_25 RE -
testcase_26 RE -
testcase_27 RE -
testcase_28 RE -
testcase_29 AC 71 ms
62,596 KB
testcase_30 AC 44 ms
62,460 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

class BIT: #0-indexed
    __slots__ = ["size", "tree","depth","n0"]
    def __init__(self, n):
        self.size = n
        self.tree = [0]*(n+1)
        self.depth = n.bit_length()
        self.n0 = 1<<self.depth

    def get_sum(self, i): #a_0 + ... + a_{i} #閉区間
        s = 0; i += 1
        while i > 0:
            s += self.tree[i]
            i -= i & -i
        return s

    def range_sum(self,l,r): #a_l + ... + a_r 閉区間
        return self.get_sum(r) - self.get_sum(l-1) 

    def range_sum_larger(self,l): #a_l + ... (端まで)
        return self.get_sum(self.size-1) - (self.get_sum(l-1) if l else 0)
    
    def add(self, i, x):
        i += 1
        while i <= self.size:
            self.tree[i] += x
            i += i & -i

SIZE=10**5+1; MOD=998244353 #10**9+7 #ここを変更する

inv = [0]*SIZE  # inv[j] = j^{-1} mod MOD
fac = [0]*SIZE  # fac[j] = j! mod MOD
finv = [0]*SIZE # finv[j] = (j!)^{-1} mod MOD
fac[0] = fac[1] = 1
finv[0] = finv[1] = 1
for i in range(2,SIZE):
    fac[i] = fac[i-1]*i%MOD
finv[-1] = pow(fac[-1],MOD-2,MOD)
for i in range(SIZE-1,0,-1):
    finv[i-1] = finv[i]*i%MOD
    inv[i] = finv[i]*fac[i-1]%MOD

def choose(n,r): # nCk mod MOD の計算
    if 0 <= r <= n:
        return (fac[n]*finv[r]%MOD)*finv[n-r]%MOD
    else:
        return 0
def perm(n,r): # nPr mod MOD
    if 0 <= r <= n:
        return fac[n]*finv[n-r]%MOD
    else:
        return 0

import sys
readline = sys.stdin.readline

#n = int(readline())
#n,Q = map(int,readline().split())

n = int(readline())
*a, = map(int,readline().split())

# from itertools import permutations
# cnt = [0]*n
# for lst in permutations(list(range(1,n+1))):
#     for i in range(n):
#         if a[i] != -1 and lst[i] != a[i]: break
#     else:
#         m = 0
#         for i in range(n):
#             m = max(m,lst[i])
#             if m==lst[i]: cnt[i] += 1
# print(cnt)


bit = BIT(n+1)
for i in range(1,n+1):
    bit.add(i,1)
    
for ai in a:
    if ai != -1:
        bit.add(ai,-1)


tot = a.count(-1)
ans = 0
m = 0
cnt = 0
for i,ai in enumerate(a):
    if ai == -1:
        cnt += 1
        c = bit.get_sum(m)
        r = (fac[tot]-perm(c,cnt)*fac[tot-cnt])%MOD*inv[cnt]
        #print((perm(m-i-1+cnt,cnt)*fac[tot-cnt])%MOD,r%MOD)
        #print(r%MOD)
        ans = (ans+r)%MOD
    else:
        if m < ai:
            c = bit.get_sum(ai)
            r = perm(c,cnt)*fac[tot-cnt]
            #print(r)
            ans = (ans+r)%MOD
            m = ai        
        else:
            r = 0
        #print(r%MOD)


print(ans)




    
0