結果

問題 No.3478 XOR-Folding Primes
コンテスト
ユーザー 👑 loop0919
提出日時 2026-01-09 22:23:15
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
AC  
実行時間 2,399 ms / 4,000 ms
コード長 1,807 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 646 ms
コンパイル使用メモリ 85,100 KB
実行使用メモリ 321,608 KB
最終ジャッジ日時 2026-03-20 20:50:31
合計ジャッジ時間 13,421 ms
ジャッジサーバーID
(参考情報)
judge3_0 / judge2_1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 8
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

from bisect import bisect_right

MOD = 998244353
LIMIT = 10**7

prime_count = [0] * (LIMIT + 1)
good_primes = [0] * (LIMIT + 1)


def prepare():
    global prime_count, good_primes

    is_prime = [True] * (LIMIT + 1)
    is_prime[0] = is_prime[1] = False

    for i in range(2, LIMIT + 1):
        if not is_prime[i]:
            continue

        for j in range(2 * i, LIMIT + 1, i):
            is_prime[j] = False

    for i in range(2, LIMIT + 1):
        prime_count[i] = prime_count[i - 1] + (1 if is_prime[i] else 0)

    for i in range(1, LIMIT - 1):
        j = i ^ 2
        good_primes[i] = good_primes[i - 1]

        if j < i and is_prime[i] and is_prime[j]:
            good_primes[i] += 2


def multiple(A: list[list[int]], B: list[list[int]]):
    h_1, w_1 = len(A), len(A[0])
    h_2, w_2 = len(B), len(B[0])

    assert w_1 == h_2

    result = [[0] * w_2 for _ in range(h_1)]

    for i in range(h_1):
        for j in range(w_2):
            for k in range(w_1):
                result[i][j] = (result[i][j] + A[i][k] * B[k][j]) % MOD

    return result


def power(A: list[list[int]], e: int):
    n = len(A)
    result = [[1 if i == j else 0 for j in range(n)] for i in range(n)]

    for i in range(e.bit_length() - 1, -1, -1):
        result = multiple(result, result)
        if (e >> i) & 1 == 1:
            result = multiple(result, A)

    return result


def solve():
    N, M = [int(s) for s in input().split()]

    if N == 1:
        print(prime_count[M])
        return

    count = good_primes[M]

    dp = [[1], [count]]
    matrix = [[0, 1], [count, 1]]

    result = multiple(power(matrix, N - 1), dp)
    ans = (result[0][0] + result[1][0]) % MOD

    print(ans)


if __name__ == "__main__":
    prepare()

    T = int(input())
    for _ in range(T):
        solve()
0