結果

問題 No.2951 Similar to Mex
ユーザー nouka28nouka28
提出日時 2024-04-03 17:36:31
言語 PyPy3
(7.3.15)
結果
RE  
(最新)
AC  
(最初)
実行時間 -
コード長 3,467 bytes
コンパイル時間 554 ms
コンパイル使用メモリ 82,492 KB
実行使用メモリ 84,356 KB
最終ジャッジ日時 2024-10-25 20:50:08
合計ジャッジ時間 5,287 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 47 ms
67,496 KB
testcase_01 AC 50 ms
67,284 KB
testcase_02 AC 96 ms
82,356 KB
testcase_03 AC 83 ms
82,260 KB
testcase_04 AC 61 ms
72,420 KB
testcase_05 AC 88 ms
82,424 KB
testcase_06 AC 58 ms
72,080 KB
testcase_07 AC 64 ms
73,744 KB
testcase_08 AC 60 ms
74,700 KB
testcase_09 AC 56 ms
70,428 KB
testcase_10 AC 62 ms
73,544 KB
testcase_11 RE -
testcase_12 RE -
testcase_13 RE -
testcase_14 RE -
testcase_15 RE -
testcase_16 RE -
testcase_17 RE -
testcase_18 RE -
testcase_19 RE -
testcase_20 RE -
testcase_21 RE -
testcase_22 RE -
testcase_23 RE -
testcase_24 RE -
testcase_25 RE -
testcase_26 RE -
testcase_27 RE -
testcase_28 RE -
testcase_29 RE -
testcase_30 RE -
testcase_31 RE -
testcase_32 RE -
testcase_33 RE -
testcase_34 RE -
testcase_35 RE -
testcase_36 AC 49 ms
68,632 KB
testcase_37 AC 47 ms
67,112 KB
testcase_38 AC 49 ms
67,144 KB
testcase_39 RE -
testcase_40 RE -
testcase_41 AC 48 ms
67,696 KB
testcase_42 RE -
testcase_43 RE -
testcase_44 RE -
権限があれば一括ダウンロードができます

ソースコード

diff #

def f(a:list,x:int):
    a=set(a)
    while x in a:x+=1
    return x

def g(a:list,m:int):
    ret=1;mod=998244353
    for i in range(1,m+1):
        ret*=f(a,i)
        ret%=mod
    return ret

import itertools
def naive(N,M,K):
    ret=0
    mod=998244353
    for v in itertools.product(range(1,M+1),repeat=N):
        ret+=g(list(v),K)
        ret%=mod
    return ret

def fast(N,M,K):
    DP=[[[[0,0]for _ in range(89)]for _ in range(89)]for _ in range(89)]
    bi=[[0]*89 for _ in range(89)]
    bi[0][0]=1
    mod=998244353
    for i in range(85):
        for j in range(85):
            bi[i+1][j]+=bi[i][j]
            bi[i+1][j]%=mod
            bi[i+1][j+1]+=bi[i][j]
            bi[i+1][j+1]%=mod
    
    fit=lambda l,r,x:0 if x<l else (min(r,x)-l+1)
    DP[1][N][0][0]=1
    for i in range(1,M+2):
        for j in range(N+1):
            for k in range(i):
                #DP[i][j][k][flg] : 1..i-1まで見て、j個の空きがあり、現在、i-k...i-1 が全て含まれている、flgはすでに一つ以上のmexが出たかどうか
                if i<=M:
                    for l in range(j+1):
                        if l==0:
                            DP[i+1][j][0][1]+=(DP[i][j][k][0]+DP[i][j][k][1])*pow(i,fit(i-k,i,K),mod)
                            DP[i+1][j][0][1]%=mod
                        else:
                            DP[i+1][j-l][k+1][0]+=DP[i][j][k][0]*bi[j][l]
                            DP[i+1][j-l][k+1][0]%=mod
                            DP[i+1][j-l][k+1][1]+=DP[i][j][k][1]*bi[j][l]
                            DP[i+1][j-l][k+1][1]%=mod
                    pass
                else:
                    if j>0:continue
                    tmp=(DP[i][j][k][0]+DP[i][j][k][1])*pow(i,fit(i-k,i,K),mod)
                    for l in range(i+1,K+1):
                        tmp*=l
                        tmp%=mod
                    DP[i+1][0][0][1]+=tmp
                    DP[i+1][0][0][1]%=mod
    return DP[M+2][0][0][1]

def fast2(N,M,K):
    DP=[[[0]*89 for _ in range(89)]for _ in range(89)]
    bi=[[0]*89 for _ in range(89)]
    bi[0][0]=1
    mod=998244353
    for i in range(85):
        for j in range(85):
            bi[i+1][j]+=bi[i][j]
            bi[i+1][j]%=mod
            bi[i+1][j+1]+=bi[i][j]
            bi[i+1][j+1]%=mod
    
    fit=lambda l,r,x:0 if x<l else (min(r,x)-l+1)
    DP[1][N][0]=1
    for i in range(1,M+2):
        for j in range(N+1):
            for k in range(i):
                #DP[i][j][k][flg] : 1..i-1まで見て、j個の空きがあり、現在、i-k...i-1 が全て含まれている、flgはすでに一つ以上のmexが出たかどうか
                if i<=M:
                    for l in range(j+1):
                        if l==0:
                            DP[i+1][j][0]+=DP[i][j][k]*pow(i,fit(i-k,i,K),mod)
                            DP[i+1][j][0]%=mod
                        else:
                            DP[i+1][j-l][k+1]+=DP[i][j][k]*bi[j][l]
                            DP[i+1][j-l][k+1]%=mod
                    pass
                else:
                    if j>0:continue
                    tmp=(DP[i][j][k])*pow(i,fit(i-k,i,K),mod)
                    for l in range(i+1,K+1):
                        tmp*=l
                        tmp%=mod
                    DP[i+1][0][0]+=tmp
                    DP[i+1][0][0]%=mod
    return DP[M+2][0][0]

def main():
    N,M,K=map(int,input().split())
    print(fast2(N,M,K))    
    
    
main()
0