結果

問題 No.1762 🐙🐄🌲
ユーザー lam6er
提出日時 2025-04-09 20:56:10
言語 PyPy3 (7.3.15)
結果 AC
実行時間 2,162 ms / 4,000 ms
コード長 4,376 bytes
コンパイル時間 199 ms
コンパイル使用メモリ 82,716 KB
実行使用メモリ 265,760 KB
最終ジャッジ日時 2025-04-09 20:58:16
合計ジャッジ時間 18,315 ms
ジャッジサーバーID (参考情報) judge4 / judge3
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 47
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353


def _factorials(limit):
    """Return ``(fact, inv_fact)`` holding ``i!`` and ``1/i!`` mod MOD for 0 <= i <= limit.

    Assumes ``0 <= limit < MOD`` so that every factorial is invertible.
    """
    fact = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (limit + 1)
    # A single Fermat inverse at the top, then 1/(i-1)! = i * (1/i!) downwards.
    inv_fact[limit] = pow(fact[limit], MOD - 2, MOD)
    for i in range(limit, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    return fact, inv_fact


def _power_coeff(f, m, k, fact, inv_fact):
    """Return ``[x^k] f(x)^m`` modulo MOD for a polynomial ``f`` with ``f[0] == 1``.

    ``g = f^m`` satisfies ``f * g' = m * f' * g``. Comparing the coefficients
    of ``x^(i-1)`` on both sides gives, for ``i >= 1``::

        i * g[i] = sum_{j=1}^{min(deg f, i)} f[j] * ((m + 1) * j - i) * g[i - j]

    Each coefficient therefore costs O(deg f), for O(k * deg f) in total.
    No FFT/NTT is needed. Division by ``i`` uses ``1/i = (i-1)! * (1/i!)``,
    so ``fact``/``inv_fact`` must cover index ``k``, and ``k < MOD`` is required.
    """
    deg = len(f) - 1
    g = [0] * (k + 1)
    g[0] = 1  # f(0)^m == 1
    for i in range(1, k + 1):
        acc = 0
        for j in range(1, min(deg, i) + 1):
            acc += f[j] * (((m + 1) * j - i) % MOD) % MOD * g[i - j]
        g[i] = acc % MOD * fact[i - 1] % MOD * inv_fact[i] % MOD
    return g[k]


def solve(n, p):
    """Return the answer for input ``(N, P) = (n, p)`` modulo MOD.

    With ``C = (N-1)/4``, ``O = N - C``, ``K = C - 1 - 7P``, ``m = O - P`` and
    ``f(x) = sum_{s=0}^{6} x^s / s!``, the result is::

        C(N, C) * C(O, P) * (3C)! / 3!^C * (C-1)! / 7!^P * [x^K] f(x)^m

    The result is 0 when ``N - 1`` is not a multiple of 4, or when ``P`` or
    ``K`` is out of range.

    NOTE(review): the factor structure looks like a bipartite Prufer-code
    tree count; confirm against the problem statement before reusing it.
    """
    if (n - 1) % 4 != 0:
        return 0
    c = (n - 1) // 4
    rest = n - c
    # Also covers rest < 0, because then either p < 0 or p > rest.
    if p < 0 or p > rest:
        return 0
    k = c - 1 - 7 * p
    m = rest - p
    if k < 0 or k > 6 * m:
        return 0

    # The largest factorial index used below is n (rest, 3c, c-1 and k are all
    # smaller). The floor of 7 keeps inv_fact[7] = 1/7! available.
    fact, inv_fact = _factorials(max(n, 7))

    def binom(a, b):
        # Callers guarantee 0 <= b <= a here.
        return fact[a] * inv_fact[b] % MOD * inv_fact[a - b] % MOD

    ans = binom(n, c) * binom(rest, p) % MOD
    # (3C)! / 3!^C, with inv_fact[3] == 1/3! == 1/6
    ans = ans * fact[3 * c] % MOD * pow(inv_fact[3], c, MOD) % MOD
    # (C-1)! / 7!^P, with inv_fact[7] == 1/7!
    ans = ans * fact[c - 1] % MOD * pow(inv_fact[7], p, MOD) % MOD
    # f(x) = sum_{s=0}^{6} x^s / s!
    f = inv_fact[:7]
    return ans * _power_coeff(f, m, k, fact, inv_fact) % MOD


def main():
    """Read ``N P`` from one line of stdin and print the answer."""
    import sys
    n, p = map(int, sys.stdin.readline().split())
    print(solve(n, p))


if __name__ == '__main__':
    main()
0