結果

問題 No.315 世界のなんとか3.5
ユーザー lam6er
提出日時 2025-04-16 16:28:57
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,560 bytes
コンパイル時間 579 ms
コンパイル使用メモリ 82,476 KB
実行使用メモリ 93,180 KB
最終ジャッジ日時 2025-04-16 16:29:38
合計ジャッジ時間 7,881 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 14 TLE * 1 -- * 21
権限があれば一括ダウンロードができます

ソースコード

diff #

# Modulus applied to every DP count; the final answer is printed reduced modulo this value.
MOD = 10**9 + 7

def subtract_one(s):
    """Return the decimal string of ``int(s) - 1``, computed on the string itself.

    Trailing zeros borrow (become '9') and the last non-zero digit is
    decremented. An empty or all-zero string yields '0'. If the result starts
    with '0' and is longer than one character, that single leading '0' is
    dropped (this is the zero left behind when a leading '1' is borrowed from).
    """
    head = s.rstrip('0')
    if not head:
        return '0'
    borrowed = len(s) - len(head)
    decremented = head[:-1] + str(int(head[-1]) - 1) + '9' * borrowed
    if len(decremented) > 1 and decremented[0] == '0':
        return decremented[1:]
    return decremented

def count_aho(X, P):
    """Count integers v with 0 <= v <= int(X) that are divisible by 3 or contain the digit 3.

    X is a decimal digit string, consumed from its most significant digit.
    P is accepted only so the signature mirrors count_aho_and_p; it is unused.
    Returns the count reduced modulo MOD.
    """
    # (digit sum mod 3, seen a '3' yet, prefix still equal to X's prefix) -> ways
    table = {(0, 0, 1): 1}
    for ch in X:
        limit = int(ch)
        following = {}
        for (residue, seen, bound), ways in table.items():
            top = limit if bound else 9
            for d in range(top + 1):
                key = ((residue + d) % 3, seen | (d == 3), int(bound and d == top))
                following[key] = (following.get(key, 0) + ways) % MOD
        table = following
    # A number is divisible by 3 exactly when its digit sum is.
    return sum(ways for (residue, seen, _), ways in table.items() if residue == 0 or seen) % MOD

def count_aho_and_p(X, P):
    """Count integers v with 0 <= v <= int(X) that are divisible by P and also
    either divisible by 3 or contain the digit 3.

    X is a decimal digit string; P is a positive integer. Returns the count
    reduced modulo MOD.

    Digit DP from the most significant digit. The residue mod P is built as a
    weighted sum ``sum(d_i * (10**e_i % P)) % P`` rather than the Horner update
    ``(r * 10 + d) % P``. At any position whose weight is 0 (for example every
    position above the last k+3 digits when P == 8 * 10**k), a digit leaves the
    residue unchanged. The residue therefore stays 0 across such a prefix.
    States are kept sparsely in a dict, so only reachable residues cost work.
    The expensive part is confined to the final positions whose weight is
    non-zero, instead of O(P) work at every digit.
    """
    digits = [int(c) for c in X]
    n = len(digits)

    # weights[i] == 10 ** (n - 1 - i) % P: residue contributed per unit of digit i.
    weights = [0] * n
    w = 1 % P  # 1 % P so that P == 1 gives weight 0 everywhere
    for i in range(n - 1, -1, -1):
        weights[i] = w
        w = w * 10 % P

    # (digit sum mod 3, has a '3', value mod P, tight to X's prefix) -> count mod MOD
    dp = {(0, 0, 0, 1): 1}
    for limit, weight in zip(digits, weights):
        nxt = {}
        for (mod3, has3, modp, tight), cnt in dp.items():
            upper = limit if tight else 9
            for d in range(upper + 1):
                key = (
                    (mod3 + d) % 3,
                    has3 | (d == 3),
                    (modp + d * weight) % P,
                    1 if tight and d == upper else 0,
                )
                nxt[key] = (nxt.get(key, 0) + cnt) % MOD
        dp = nxt

    total = 0
    for (mod3, has3, modp, _tight), cnt in dp.items():
        if (mod3 == 0 or has3) and modp == 0:
            total += cnt
    return total % MOD

def main():
    """Read "A B P" from one line of stdin and print the number of integers in
    [A, B] that are divisible by 3 or contain a 3 but are not multiples of P,
    modulo MOD.

    Both DP helpers count over [0, X], so each range count is f(B) - f(A - 1).
    The "divisible by P" subset is then subtracted from the full count.
    """
    A, B, P = input().split()
    P = int(P)
    # A and B stay digit strings; the DPs consume them digit by digit.
    A_minus_1 = subtract_one(A)

    total_aho = (count_aho(B, P) - count_aho(A_minus_1, P)) % MOD
    total_aho_p = (count_aho_and_p(B, P) - count_aho_and_p(A_minus_1, P)) % MOD

    # Python's % with a positive modulus already returns a value in [0, MOD),
    # so no negative fix-up is needed.
    print((total_aho - total_aho_p) % MOD)

# Run the solver only when executed as a script, not on import.
if __name__ == "__main__":
    main()
# NOTE(review): the bare `0` below is a no-op expression statement; it looks like
# residue from copying the submission page and can presumably be deleted.
0