結果
| 問題 |
No.2951 Similar to Mex
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2024-04-03 17:36:31 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
RE
(最新)
AC
(最初)
|
| 実行時間 | - |
| コード長 | 3,467 bytes |
| コンパイル時間 | 554 ms |
| コンパイル使用メモリ | 82,492 KB |
| 実行使用メモリ | 84,356 KB |
| 最終ジャッジ日時 | 2024-10-25 20:50:08 |
| 合計ジャッジ時間 | 5,287 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 12 RE * 30 |
ソースコード
def f(a:list,x:int):
a=set(a)
while x in a:x+=1
return x
def g(a:list,m:int):
ret=1;mod=998244353
for i in range(1,m+1):
ret*=f(a,i)
ret%=mod
return ret
import itertools
def naive(N,M,K):
ret=0
mod=998244353
for v in itertools.product(range(1,M+1),repeat=N):
ret+=g(list(v),K)
ret%=mod
return ret
def fast(N,M,K):
DP=[[[[0,0]for _ in range(89)]for _ in range(89)]for _ in range(89)]
bi=[[0]*89 for _ in range(89)]
bi[0][0]=1
mod=998244353
for i in range(85):
for j in range(85):
bi[i+1][j]+=bi[i][j]
bi[i+1][j]%=mod
bi[i+1][j+1]+=bi[i][j]
bi[i+1][j+1]%=mod
fit=lambda l,r,x:0 if x<l else (min(r,x)-l+1)
DP[1][N][0][0]=1
for i in range(1,M+2):
for j in range(N+1):
for k in range(i):
#DP[i][j][k][flg] : 1..i-1まで見て、j個の空きがあり、現在、i-k...i-1 が全て含まれている、flgはすでに一つ以上のmexが出たかどうか
if i<=M:
for l in range(j+1):
if l==0:
DP[i+1][j][0][1]+=(DP[i][j][k][0]+DP[i][j][k][1])*pow(i,fit(i-k,i,K),mod)
DP[i+1][j][0][1]%=mod
else:
DP[i+1][j-l][k+1][0]+=DP[i][j][k][0]*bi[j][l]
DP[i+1][j-l][k+1][0]%=mod
DP[i+1][j-l][k+1][1]+=DP[i][j][k][1]*bi[j][l]
DP[i+1][j-l][k+1][1]%=mod
pass
else:
if j>0:continue
tmp=(DP[i][j][k][0]+DP[i][j][k][1])*pow(i,fit(i-k,i,K),mod)
for l in range(i+1,K+1):
tmp*=l
tmp%=mod
DP[i+1][0][0][1]+=tmp
DP[i+1][0][0][1]%=mod
return DP[M+2][0][0][1]
def fast2(N,M,K):
DP=[[[0]*89 for _ in range(89)]for _ in range(89)]
bi=[[0]*89 for _ in range(89)]
bi[0][0]=1
mod=998244353
for i in range(85):
for j in range(85):
bi[i+1][j]+=bi[i][j]
bi[i+1][j]%=mod
bi[i+1][j+1]+=bi[i][j]
bi[i+1][j+1]%=mod
fit=lambda l,r,x:0 if x<l else (min(r,x)-l+1)
DP[1][N][0]=1
for i in range(1,M+2):
for j in range(N+1):
for k in range(i):
#DP[i][j][k][flg] : 1..i-1まで見て、j個の空きがあり、現在、i-k...i-1 が全て含まれている、flgはすでに一つ以上のmexが出たかどうか
if i<=M:
for l in range(j+1):
if l==0:
DP[i+1][j][0]+=DP[i][j][k]*pow(i,fit(i-k,i,K),mod)
DP[i+1][j][0]%=mod
else:
DP[i+1][j-l][k+1]+=DP[i][j][k]*bi[j][l]
DP[i+1][j-l][k+1]%=mod
pass
else:
if j>0:continue
tmp=(DP[i][j][k])*pow(i,fit(i-k,i,K),mod)
for l in range(i+1,K+1):
tmp*=l
tmp%=mod
DP[i+1][0][0]+=tmp
DP[i+1][0][0]%=mod
return DP[M+2][0][0]
def main():
N,M,K=map(int,input().split())
print(fast2(N,M,K))
main()