結果
問題 |
No.1762 🐙🐄🌲
|
ユーザー |
![]() |
提出日時 | 2025-04-09 20:56:10 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 2,162 ms / 4,000 ms |
コード長 | 4,376 bytes |
コンパイル時間 | 199 ms |
コンパイル使用メモリ | 82,716 KB |
実行使用メモリ | 265,760 KB |
最終ジャッジ日時 | 2025-04-09 20:58:16 |
合計ジャッジ時間 | 18,315 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 47 |
ソースコード
MOD = 998244353 def main(): import sys N, P = map(int, sys.stdin.readline().split()) # Precompute factorial and inverse factorial up to needed values max_fact = 5 * 10**5 * 3 # 3C can be up to ~3* (5e5/4) ~ 375e3 max_needed = max(3*((5*10**5)//4), (5*10**5) *3, (5*10**5)*7) max_fact = max(max_fact, max_needed) fact = [1] * (max_fact + 1) for i in range(1, max_fact + 1): fact[i] = fact[i-1] * i % MOD inv_fact = [1]*(max_fact +1) inv_fact[max_fact] = pow(fact[max_fact], MOD-2, MOD) for i in range(max_fact-1, -1, -1): inv_fact[i] = inv_fact[i+1] * (i+1) % MOD # Check validity if (N-1) %4 !=0: print(0) return C = (N-1)//4 O = N - C if O <0 or P > O: print(0) return K = C -1 -7*P m = O - P if K <0 or K >6*m or m <0: print(0) return # Compute combinations: C(n, C) and C(O, P) def comb(n, k): if k <0 or k >n: return 0 return fact[n] * inv_fact[k] % MOD * inv_fact[n -k] % MOD c_n_c = comb(N, C) c_o_p = comb(O, P) ans = c_n_c * c_o_p % MOD # Compute (3C)! / (3!^C) term3C = 1 term3C = term3C * fact[3*C] % MOD inv6 = pow(6, MOD-2, MOD) inv6_C = pow(inv6, C, MOD) term3C = term3C * inv6_C % MOD ans = ans * term3C % MOD # Compute (C-1)! / 7!^P if C-1 <0: print(0) return termC1 = fact[C-1] if C-1 >=0 else 1 inv7f = pow(5040, MOD-2, MOD) inv7f_P = pow(inv7f, P, MOD) termC1 = termC1 * inv7f_P % MOD ans = ans * termC1 % MOD # Compute [x^K] (sum_{s=0}^6 x^s /s! )^m # Implement NTT-based multiplication # Define NTT functions def ntt(a, inverse=False): # Cooley-Tukey FFT algorithm n = len(a) log_n = (n).bit_length() -1 rev = [0]*n for i in range(n): rev[i] = rev[i >>1] >>1 if i &1: rev[i] |= n >>1 if i < rev[i]: a[i], a[rev[i]] = a[rev[i]], a[i] root = pow(3, (MOD-1)//n, MOD) if not inverse else pow(3, MOD-1 - (MOD-1)//n, MOD) roots = [1]*(n//2) for i in range(1, len(roots)): roots[i] = roots[i-1] * root % MOD current_length = 1 while current_length < n: for i in range(0, n, 2*current_length): for j in range(current_length): idx_e = i + j idx_o = i + j + current_length even = a[idx_e] odd = a[idx_o] * roots[j * (n//(2*current_length))] % MOD a[idx_e] = (even + odd) % MOD a[idx_o] = (even - odd) % MOD if a[idx_o] <0: a[idx_o] += MOD current_length *=2 if inverse: inv_n = pow(n, MOD-2, MOD) for i in range(n): a[i] = a[i] * inv_n % MOD return a def multiply_ntt(a, b, K): # compute a * b mod x^(K+1) len_a = len(a) len_b = len(b) if len_a ==0 or len_b==0: return [] new_len = len_a + len_b -1 n = 1 while n < new_len: n <<=1 a_ntt = a + [0]*(n - len_a) b_ntt = b + [0]*(n - len_b) a_ntt = ntt(a_ntt) b_ntt = ntt(b_ntt) c_ntt = [(x*y) % MOD for x, y in zip(a_ntt, b_ntt)] c_ntt = ntt(c_ntt, inverse=True) res = [c_ntt[i] for i in range(min(new_len, K+1))] return res # Function to compute poly^exp mod x^(K+1) def poly_pow(poly, exp, K): result = [1] while exp >0: if exp %2 ==1: result = multiply_ntt(result, poly, K) poly = multiply_ntt(poly, poly, K) exp //=2 return result # Generate f(x) = sum_{s=0}^6 x^s/s! f = [0]*(7) for s in range(7): f[s] = inv_fact[s] # Compute f(x)^m mod x^{K+1} # Handle m=0 case if m ==0: if K ==0: coeff = 1 else: coeff =0 else: poly = f[:7] res_poly = poly_pow(poly, m, K) if K < len(res_poly): coeff = res_poly[K] else: coeff =0 ans = ans * coeff % MOD print(ans) if __name__ == '__main__': main()