結果

問題 No.2807 Have Another Go (Easy)
ユーザー lam6er
提出日時 2025-04-15 22:39:18
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,758 bytes
コンパイル時間 300 ms
コンパイル使用メモリ 81,724 KB
実行使用メモリ 61,024 KB
最終ジャッジ日時 2025-04-15 22:40:54
合計ジャッジ時間 4,629 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 1 TLE * 1 -- * 44
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def main():
    import sys
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr]); ptr +=1
    M = int(input[ptr]); ptr +=1
    k = int(input[ptr]); ptr +=1
    C = list(map(int, input[ptr:ptr+k]))
    ptr +=k
    
    # Precompute total valid sequences
    # For M=2, sum >=2N
    # We need to compute the total number of sequences that reach sum >=2N
    # Using dynamic programming for phases 0 and 1
    # dp0[m]: number of ways to reach sum m in phase 0 (0 <= m < N) and eventually reach >=2N
    # dp1[m]: number of ways to reach sum m in phase 1 (0 <= m < N) and eventually reach >=2N
    
    dp0 = [0] * N
    dp1 = [0] * N
    
    # Compute dp1 first
    for m in reversed(range(N)):
        # dp1[m] = sum_{s=1}^6 [ if m + s < N then dp1[m+s] else 1 ]
        res = 0
        for s in range(1, 7):
            if m + s < N:
                res += dp1[(m + s) % N]
            else:
                res += 1
        dp1[m] = res % MOD
    
    # Compute dp0
    for m in reversed(range(N)):
        res = 0
        for s in range(1, 7):
            if m + s < N:
                res += dp0[(m + s) % N]
            elif m + s < 2 * N:
                nm = (m + s - N) % N
                res += dp1[nm]
            else:
                res += 1
        dp0[m] = res % MOD
    
    total = dp0[0]
    
    # Now handle each C_i
    for c in C:
        # Compute forbidden count where no partial sum is congruent to c mod N
        # Use similar DP but exclude transitions to c
        
        # dp0f[m]: forbidden count for phase 0
        # dp1f[m]: forbidden count for phase 1
        dp0f = [0] * N
        dp1f = [0] * N
        
        # Compute dp1f
        for m in reversed(range(N)):
            if m == c:
                continue
            res = 0
            for s in range(1, 7):
                nm = (m + s) % N
                if m + s < N:
                    if nm != c:
                        res += dp1f[nm]
                else:
                    res += 1
            dp1f[m] = res % MOD
        
        # Compute dp0f
        for m in reversed(range(N)):
            if m == c:
                continue
            res = 0
            for s in range(1, 7):
                if m + s < N:
                    nm = (m + s) % N
                    if nm != c:
                        res += dp0f[nm]
                elif m + s < 2 * N:
                    nm = (m + s - N) % N
                    if nm != c:
                        res += dp1f[nm]
                else:
                    res += 1
            dp0f[m] = res % MOD
        
        forbidden = dp0f[0] if 0 != c else 0
        ans = (total - forbidden) % MOD
        print(ans)
    
if __name__ == '__main__':
    main()
0