結果

問題 No.2272 多項式乗算 mod 258280327
ユーザー qwewe
提出日時 2025-05-14 12:59:51
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 8,903 bytes
コンパイル時間 339 ms
コンパイル使用メモリ 82,228 KB
実行使用メモリ 238,472 KB
最終ジャッジ日時 2025-05-14 13:01:53
合計ジャッジ時間 10,032 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 27 WA * 3 TLE * 1 -- * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Fast modular exponentiation
def fast_pow(base, power, modulus):
    """Return ``(base ** power) % modulus`` for a non-negative exponent.

    Delegates to the built-in three-argument ``pow``, which performs
    binary exponentiation natively.  Compared with a hand-written
    square-and-multiply loop this also yields the mathematically correct
    result ``0`` when ``modulus == 1`` (including for ``power == 0``).

    Args:
        base: Integer base; may be negative or larger than ``modulus``.
        power: Non-negative integer exponent.
        modulus: Positive integer modulus.

    Returns:
        The residue of ``base ** power`` in ``range(modulus)``.

    Raises:
        ValueError: If ``power`` is negative.  The built-in ``pow`` also
            raises ``ValueError`` when ``modulus`` is zero.
    """
    if power < 0:
        # A negative exponent would need a modular inverse; refuse it
        # explicitly instead of returning a meaningless value.
        raise ValueError("power must be non-negative")
    return pow(base, power, modulus)

# Modular inverse using the extended Euclidean algorithm
def mod_inverse(a, m):
    """Return the inverse of ``a`` modulo ``m``.

    Uses the extended Euclidean algorithm, so ``m`` does not have to be
    prime; it is enough that ``gcd(a, m) == 1``.  For a prime ``m`` the
    result is the same value Fermat's little theorem would give.

    Args:
        a: Integer to invert; it is reduced modulo ``m`` first.
        m: Positive integer modulus.

    Returns:
        The integer ``x`` in ``range(m)`` with ``(a * x) % m == 1``.

    Raises:
        ValueError: If ``a`` is divisible by ``m`` or, more generally,
            shares a common factor with ``m`` (no inverse exists).
    """
    a %= m
    if a == 0:
        raise ValueError("Modular inverse does not exist for a divisible by m")

    # Invariant: prev_r == prev_s * a (mod m) and cur_r == cur_s * a (mod m).
    prev_r, cur_r = a, m
    prev_s, cur_s = 1, 0
    while cur_r:
        quotient = prev_r // cur_r
        prev_r, cur_r = cur_r, prev_r - quotient * cur_r
        prev_s, cur_s = cur_s, prev_s - quotient * cur_s

    # prev_r now holds gcd(a, m); an inverse exists only when it is 1.
    if prev_r != 1:
        raise ValueError("Modular inverse does not exist: a and m are not coprime")
    return prev_s % m

# Number Theoretic Transform (NTT)
def ntt(a, mod, root, inverse=False):
    """Apply the number-theoretic transform (or its inverse) to ``a`` in place.

    Iterative radix-2 Cooley-Tukey transform over the integers modulo
    ``mod``: a bit-reversal permutation followed by log2(n) butterfly
    levels.  The inverse transform uses the inverse roots of unity and
    finally scales every entry by ``1/n``.

    Args:
        a: List of integers; its length must be a power of two that
            divides ``mod - 1``.  The list is overwritten with the result.
        mod: Prime modulus.
        root: Primitive root modulo ``mod``.
        inverse: When true, compute the inverse transform.

    Returns:
        The same list object ``a``, for call-chaining convenience.

    Raises:
        ValueError: If ``len(a)`` is not a power of two, or does not divide
            ``mod - 1`` (no root of unity of that order exists modulo
            ``mod``, so the output would be silently wrong).
    """
    n = len(a)
    if n <= 1:
        # Lengths 0 and 1 are their own transform; nothing to do.
        return a
    if n & (n - 1):
        raise ValueError("NTT length must be a power of two")
    if (mod - 1) % n:
        raise ValueError("NTT length must divide mod - 1")

    # Bit-reversal permutation.  rev[i] is derived from rev[i >> 1], which
    # is already known because i >> 1 < i; each pair is swapped once.
    rev = [0] * n
    top_bit = n >> 1
    for i in range(1, n):
        r = (rev[i >> 1] >> 1) | (top_bit if i & 1 else 0)
        rev[i] = r
        if i < r:
            a[i], a[r] = a[r], a[i]

    len_ = 2
    while len_ <= n:
        # Primitive len_-th root of unity (inverted for the inverse pass;
        # Fermat's little theorem applies because mod is prime).
        wlen = pow(root, (mod - 1) // len_, mod)
        if inverse:
            wlen = pow(wlen, mod - 2, mod)

        half_len = len_ >> 1
        # The powers of wlen are identical for every block on this level,
        # so build them once instead of once per block.
        twiddles = [1] * half_len
        for j in range(1, half_len):
            twiddles[j] = twiddles[j - 1] * wlen % mod

        for start in range(0, n, len_):
            mid = start + half_len
            for j in range(half_len):
                u = a[start + j]
                v = twiddles[j] * a[mid + j] % mod
                a[start + j] = (u + v) % mod
                # Python's % already yields a non-negative residue.
                a[mid + j] = (u - v) % mod
        len_ <<= 1

    if inverse:
        # Undo the factor n accumulated by the forward/inverse pair.
        n_inv = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * n_inv % mod

    return a


# Main function to solve the problem
def solve():
    """Read two polynomials from stdin and print their product modulo P.

    Input, as consumed here: a line with N, a line with the coefficients
    of F, a line with M, and a line with the coefficients of G.  Index i
    of each list is used as the coefficient of x^i.

    Output: the degree of F*G modulo P = 258280327 on the first line and
    its coefficients (degree 0 first, space separated) on the second.  A
    zero product is printed as degree 0 with the single coefficient 0.

    Method: the product is computed modulo three NTT-friendly primes and
    the exact integer coefficient is rebuilt with the Chinese Remainder
    Theorem before the final reduction modulo P.  Each exact coefficient
    is below (min(N, M) + 1) * P^2, far below the ~1e27 product of the
    three primes for the sizes assumed by the original author
    (N, M <= 2e5 - NOTE(review): taken from the original comment, confirm
    against the problem statement).
    """
    P = 258280327  # target modulus

    # NTT-friendly primes of the form c * 2^k + 1 with a primitive root:
    #   998244353  = 119 * 2^23 + 1, primitive root 3
    #   1004535809 = 479 * 2^21 + 1, primitive root 3
    #   1012924417 = 483 * 2^21 + 1, primitive root 5
    MODS = (998244353, 1004535809, 1012924417)
    ROOTS = (3, 3, 5)

    N = int(sys.stdin.readline())
    F_coeffs = list(map(int, sys.stdin.readline().split()))
    M = int(sys.stdin.readline())
    G_coeffs = list(map(int, sys.stdin.readline().split()))

    # Fast path: an explicit zero polynomial (degree 0, coefficient 0)
    # makes the whole product zero.
    is_F_zero = N == 0 and F_coeffs == [0]
    is_G_zero = M == 0 and G_coeffs == [0]
    if is_F_zero or is_G_zero:
        print(0)
        print(0)
        return

    # (F * G) mod P == ((F mod P) * (G mod P)) mod P, and reducing first
    # keeps the exact coefficients small enough for the CRT bound above.
    F_mod_P = [c % P for c in F_coeffs]
    G_mod_P = [c % P for c in G_coeffs]

    # The product has L + 1 coefficients; use the smallest power of two
    # that can hold them as the transform length.
    L = N + M
    n = 1
    while n <= L:
        n <<= 1

    # Zero-pad once.  Every value is already below P, which is smaller
    # than each NTT prime, so no further reduction per prime is needed.
    f_padded = F_mod_P + [0] * (n - len(F_mod_P))
    g_padded = G_mod_P + [0] * (n - len(G_mod_P))

    # Product coefficients modulo each NTT prime.
    results_mod_pi = []
    for mod, root in zip(MODS, ROOTS):
        # ntt works in place, so hand it a fresh copy for every prime.
        ntt_F = ntt(f_padded[:], mod, root)
        ntt_G = ntt(g_padded[:], mod, root)
        ntt_H = [x * y % mod for x, y in zip(ntt_F, ntt_G)]
        H_pi = ntt(ntt_H, mod, root, inverse=True)
        results_mod_pi.append(H_pi[:L + 1])

    # CRT weights C_i = (M / m_i) * inverse(M / m_i mod m_i): C_i is 1
    # modulo m_i and 0 modulo the other two primes.  The primes are
    # distinct, so every inverse exists; a failure here would be a
    # programming error and is allowed to propagate.
    M_prod = MODS[0] * MODS[1] * MODS[2]
    crt_weights = []
    for mod in MODS:
        cofactor = M_prod // mod
        crt_weights.append(cofactor * mod_inverse(cofactor, mod))
    C1, C2, C3 = crt_weights

    # Rebuild each exact coefficient modulo M_prod, then reduce modulo P.
    H_final = [
        (r1 * C1 + r2 * C2 + r3 * C3) % M_prod % P
        for r1, r2, r3 in zip(*results_mod_pi)
    ]

    # The degree is the index of the highest coefficient that is non-zero
    # modulo P.
    L_actual = L
    while L_actual >= 0 and H_final[L_actual] == 0:
        L_actual -= 1

    if L_actual < 0:
        # Every coefficient vanished modulo P: print the zero polynomial.
        print(0)
        print(0)
    else:
        print(L_actual)
        # One join is much cheaper than unpacking a huge list into print().
        print(" ".join(map(str, H_final[:L_actual + 1])))

# Script entry point: run the solver as soon as the file is executed.
# NOTE: the call is not guarded by `if __name__ == "__main__":`, so merely
# importing this module also runs solve() and reads from stdin.
solve()
0