結果

問題 No.732 3PrimeCounting
ユーザー onakasuitacity
提出日時 2020-10-17 14:52:46
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 817 ms / 3,000 ms
コード長 1,770 bytes
コンパイル時間 246 ms
コンパイル使用メモリ 82,432 KB
実行使用メモリ 272,428 KB
最終ジャッジ日時 2024-07-21 02:29:53
合計ジャッジ時間 17,871 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 89
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
INF = 1 << 60
MOD = 998244353
sys.setrecursionlimit(2147483647)
input = lambda:sys.stdin.readline().rstrip()

prime = 998244353
root = 3
def _fmt(A, inverse = False):
    N = len(A)
    logN = (N - 1).bit_length()
    base = pow(root, (prime - 1) // N * (1 - 2 * inverse) % (prime - 1), prime)
    step = N
    for k in range(logN):
        step >>= 1
        w = pow(base, step, prime)
        wj = 1
        nA = [0] * N
        for j in range(1 << k):
            for i in range(1 << logN - k - 1):
                s, t = i + j * step, i + j * step + (N >> 1)
                ps, pt = i + j * step * 2, i + j * step * 2 + step
                nA[s], nA[t] = (A[ps] + A[pt] * wj) % prime, (A[ps] - A[pt] * wj) % prime
            wj = (wj * w) % prime
        A = nA
    return A

def convolution(f, g):
    N = 1 << (len(f) + len(g) - 2).bit_length()
    Ff, Fg = _fmt(f + [0] * (N - len(f))), _fmt(g + [0] * (N - len(g)))
    N_inv = pow(N, prime - 2, prime)
    fg = _fmt([a * b % prime * N_inv % prime for a, b in zip(Ff, Fg)], inverse = True)
    del fg[len(f) + len(g) - 1:]
    return fg

def resolve():
    n = int(input())
    N = 3 * n
    primes = []
    sieve = list(range(N + 1))
    for i in range(2, N + 1):
        if sieve[i] == i:
            primes.append(i)
        for p in primes:
            if sieve[i] < p or i * p > N:
                break
            sieve[i * p] = p

    f = [0] * (n + 1)
    g = [0] * (2 * n + 1)
    for p in primes:
        if p <= n:
            f[p] += 1
            g[p * 2] += 1

    a = 0
    f3 = convolution(f, convolution(f, f))
    for p in primes:
        a += f3[p]

    b = 0
    fg = convolution(f, g)
    for p in primes:
        b += fg[p]

    ans = (a - 3 * b) // 6
    print(ans)
resolve()
0