結果

問題 No.1783 Remix Sum
ユーザー lam6er
提出日時 2025-04-09 21:01:20
言語 PyPy3 (7.3.15)
結果
MLE  
実行時間 -
コード長 3,027 bytes
コンパイル時間 399 ms
コンパイル使用メモリ 82,684 KB
実行使用メモリ 791,368 KB
最終ジャッジ日時 2025-04-09 21:02:56
合計ジャッジ時間 14,630 ms
ジャッジサーバーID (参考情報) judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other MLE * 1 -- * 75
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 115 * 2 ** 20 + 1  # modulus for all printed counts; equals 120586241

def _state_digits(value, K):
    """Return the K least-significant decimal digits of ``value``, LSB first."""
    digits = []
    for _ in range(K):
        value, d = divmod(value, 10)
        digits.append(d)
    return tuple(digits)


def _convolve(u, v, state_digits, K, T, mod):
    """Return the weight distribution of ``x (+) y`` for x taken from u, y from v.

    ``u`` and ``v`` map a state index to its weight (mod ``mod``); states that
    are not keys have weight 0.  ``(+)`` adds two states digit by digit, using
    the LSB-first digits in ``state_digits``:

    * the first ``T`` digits are added without carry, and a pair whose sum
      exceeds 9 in any of them is dropped (it is not a valid state);
    * the remaining ``K - T`` digits wrap around mod 10.

    The result keeps only nonzero weights so later steps stay sparse.
    """
    out = {}
    for x, wx in u.items():
        dx = state_digits[x]
        for y, wy in v.items():
            dy = state_digits[y]
            idx = 0
            place = 1
            for i in range(K):
                d = dx[i] + dy[i]
                if i < T:
                    if d > 9:
                        break  # overflow in a no-carry digit: drop this pair
                else:
                    d %= 10
                idx += d * place
                place *= 10
            else:  # loop finished without break -> valid combined state
                out[idx] = (out.get(idx, 0) + wx * wy) % mod
    return {s: w for s, w in out.items() if w}


def main():
    """Read ``N K M T`` and the values ``a``, then print one count per state.

    A state is a K-digit vector (LSB first), indexed by the number it spells,
    so there are ``10**K`` states.  One step picks an element of ``a`` and adds
    its K low digits to the state:

    * the first ``T`` digits without carry, never exceeding 9;
    * the other digits mod 10.

    Starting from the all-zero state, the function prints, for every state
    ``s`` in ``0 .. 10**K - 1`` (one per line), the number of ordered
    length-``M`` pick sequences that end in ``s``, modulo ``MOD``.
    """
    data = sys.stdin.read().split()
    N, K, M, T = map(int, data[:4])
    a = list(map(int, data[4:4 + N]))
    size = 10 ** K

    # Why no transition matrix: a step only adds a digit vector, so the
    # M-step count from state 0 is the M-fold (+)-convolution of the one-step
    # distribution with itself.  The no-carry digits only grow, so a valid
    # final state implies every intermediate state was valid, and (+) with
    # dropped overflows is associative.  Squaring distributions therefore
    # gives the same numbers as row 0 of A**M, using O(10**K) memory instead
    # of O(10**(2K)).
    # NOTE(review): each convolution is still O(nnz**2 * K) time; if the
    # constraints make that too slow, a per-digit transform is needed.
    state_digits = [_state_digits(s, K) for s in range(size)]

    # One-step distribution from state 0.  The K low digits of num, read back
    # as a number, are exactly num % 10**K (Python floor semantics).
    counts = {}
    for num in a:
        key = num % size
        counts[key] = counts.get(key, 0) + 1
    base = {s: c % MOD for s, c in counts.items() if c % MOD}

    result = {0: 1}  # zero steps: only the all-zero state, once
    power = M
    while power > 0:
        if power & 1:
            result = _convolve(result, base, state_digits, K, T, MOD)
        power >>= 1
        if power:  # skip the squaring whose result would never be used
            base = _convolve(base, base, state_digits, K, T, MOD)

    lines = (str(result.get(s, 0) % MOD) for s in range(size))
    sys.stdout.write("\n".join(lines) + "\n")

if __name__ == '__main__':
    # Solve only when executed as a script; importing the module does nothing.
    main()
0