結果

問題 No.3044 よくあるカエルさん
ユーザー N-noa21
提出日時 2025-03-01 02:05:29
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 1,721 bytes
コンパイル時間 518 ms
コンパイル使用メモリ 82,656 KB
実行使用メモリ 79,348 KB
最終ジャッジ日時 2025-03-01 02:05:36
合計ジャッジ時間 6,545 ms
ジャッジサーバーID
(参考情報)
judge2 / judge6
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other WA * 20
権限があれば一括ダウンロードができます

ソースコード

diff #

#a*bを計算、O(ha*wa*hb*wb)
#初期値は単位行列を使う
def calc(a,b):
    #print("a",a)
    #print("b",b)
    ha = len(a)
    wa = len(a[0])
    hb = len(b)
    wb = len(b[0])

    c = [[0]*wb for i in range(ha)]

    for i in range(ha):

        for j in range(wb):
            tmp = 0
            for k in range(wa):
                tmp += a[i][k] * b[k][j]
            c[i][j] = tmp%mod
    return c

def powcalc(a,N):#aをN乗
    cnt = N.bit_length()
    l = [a]
    print(cnt)
    for _ in range(cnt-1):
        l.append(calc(l[-1],l[-1]))
    
    E = [[0]*len(a) for _ in range(len(a))]#単位行列
    for i in range(len(a)):
        E[i][i] = 1
    
    for i in range(cnt):
        if (N >> i) & 1:
            E = calc(E,l[i])
    return E

mod = 998244353
N,T = map(int, input().split())
k,l = map(int, input().split())
six_rev = pow(6,-1,mod)
if N <= T:
    dp = [0] * N
    dp[0] = 1

    for i in range(N):
        if i+1<N:
            dp[i+1] += dp[i] * (k-1) * six_rev
            dp[i+1] %= mod
        if i+2<N:
            dp[i+2] += dp[i] * (l-k) * six_rev
            dp[i+2] %= mod
    print(dp[-1])
elif N > T:
    dp = [0] * T
    dp[0] = 1

    for i in range(T):
        if i+1<T:
            dp[i+1] += dp[i] * (k-1) * six_rev
            dp[i+1] %= mod
        if i+2<T:
            dp[i+2] += dp[i] * (l-k) * six_rev
            dp[i+2] %= mod
    #print(dp[-1])

    G = [[0]*T for _ in range(T)]

    G[0][0] = (k-1)*six_rev%mod
    G[0][1] = (l-k)*six_rev%mod
    G[0][-1] = (7-l)*six_rev%mod

    for i in range(1,T):
        G[i][i-1] = 1
    G2 = [[i] for i in dp[::-1]]
    #print(G,G2)
    GG = powcalc(G,N-T)

    ans = calc(GG,G2)
    print(ans[0][0]%mod)

    




0