結果

問題 No.3478 XOR-Folding Primes
コンテスト
ユーザー 👑 loop0919
提出日時 2026-01-07 21:31:23
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
AC  
実行時間 2,638 ms / 4,000 ms
コード長 1,795 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 245 ms
コンパイル使用メモリ 85,736 KB
実行使用メモリ 243,336 KB
最終ジャッジ日時 2026-03-20 20:50:19
合計ジャッジ時間 14,965 ms
ジャッジサーバーID
(参考情報)
judge2_0 / judge1_1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 8
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

from bisect import bisect_right

MOD = 998244353

prime_count: list
good_primes: list


def prepare(limit=10**7):
    global prime_count, good_primes

    is_prime = [True] * (limit + 1)
    is_prime[0] = is_prime[1] = False

    for i in range(2, limit + 1):
        if not is_prime[i]:
            continue

        for j in range(2 * i, limit + 1, i):
            is_prime[j] = False

    prime_count = [0] * (limit + 1)
    for i in range(2, limit + 1):
        prime_count[i] = prime_count[i - 1] + (1 if is_prime[i] else 0)

    good_primes = []
    for i in range(1, limit - 1, 4):
        if is_prime[i] and is_prime[i + 2]:
            good_primes.append(i)


def multiple(A: list[list[int]], B: list[list[int]]):
    h_1, w_1 = len(A), len(A[0])
    h_2, w_2 = len(B), len(B[0])

    assert w_1 == h_2

    result = [[0] * w_2 for _ in range(h_1)]

    for i in range(h_1):
        for j in range(w_2):
            for k in range(w_1):
                result[i][j] = (result[i][j] + A[i][k] * B[k][j]) % MOD

    return result


def power(A: list[list[int]], e: int):
    n = len(A)
    result = [[1 if i == j else 0 for j in range(n)] for i in range(n)]

    for i in range(e.bit_length() - 1, -1, -1):
        result = multiple(result, result)
        if (e >> i) & 1 == 1:
            result = multiple(result, A)

    return result


def solve():
    N, M = [int(s) for s in input().split()]

    if N == 1:
        print(prime_count[M])
        return

    count = bisect_right(good_primes, M - 2)

    dp = [[1], [2 * count]]
    matrix = [[0, 1], [2 * count, 1]]

    result = multiple(power(matrix, N - 1), dp)
    ans = (result[0][0] + result[1][0]) % MOD

    print(ans)


if __name__ == "__main__":
    prepare()

    T = int(input())
    for _ in range(T):
        solve()
0