結果
問題 |
No.3044 よくあるカエルさん
|
ユーザー |
|
提出日時 | 2025-03-01 02:06:19 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 633 ms / 2,000 ms |
コード長 | 1,722 bytes |
コンパイル時間 | 454 ms |
コンパイル使用メモリ | 82,768 KB |
実行使用メモリ | 78,868 KB |
最終ジャッジ日時 | 2025-03-01 02:06:26 |
合計ジャッジ時間 | 6,351 ms |
ジャッジサーバーID (参考情報) |
judge6 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 20 |
ソースコード
#a*bを計算、O(ha*wa*hb*wb) #初期値は単位行列を使う def calc(a,b): #print("a",a) #print("b",b) ha = len(a) wa = len(a[0]) hb = len(b) wb = len(b[0]) c = [[0]*wb for i in range(ha)] for i in range(ha): for j in range(wb): tmp = 0 for k in range(wa): tmp += a[i][k] * b[k][j] c[i][j] = tmp%mod return c def powcalc(a,N):#aをN乗 cnt = N.bit_length() l = [a] #print(cnt) for _ in range(cnt-1): l.append(calc(l[-1],l[-1])) E = [[0]*len(a) for _ in range(len(a))]#単位行列 for i in range(len(a)): E[i][i] = 1 for i in range(cnt): if (N >> i) & 1: E = calc(E,l[i]) return E mod = 998244353 N,T = map(int, input().split()) k,l = map(int, input().split()) six_rev = pow(6,-1,mod) if N <= T: dp = [0] * N dp[0] = 1 for i in range(N): if i+1<N: dp[i+1] += dp[i] * (k-1) * six_rev dp[i+1] %= mod if i+2<N: dp[i+2] += dp[i] * (l-k) * six_rev dp[i+2] %= mod print(dp[-1]) elif N > T: dp = [0] * T dp[0] = 1 for i in range(T): if i+1<T: dp[i+1] += dp[i] * (k-1) * six_rev dp[i+1] %= mod if i+2<T: dp[i+2] += dp[i] * (l-k) * six_rev dp[i+2] %= mod #print(dp[-1]) G = [[0]*T for _ in range(T)] G[0][0] = (k-1)*six_rev%mod G[0][1] = (l-k)*six_rev%mod G[0][-1] = (7-l)*six_rev%mod for i in range(1,T): G[i][i-1] = 1 G2 = [[i] for i in dp[::-1]] #print(G,G2) GG = powcalc(G,N-T) ans = calc(GG,G2) print(ans[0][0]%mod)