結果
問題 |
No.2959 Dolls' Tea Party
|
ユーザー |
![]() |
提出日時 | 2025-06-12 21:18:55 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,663 bytes |
コンパイル時間 | 282 ms |
コンパイル使用メモリ | 81,912 KB |
実行使用メモリ | 257,400 KB |
最終ジャッジ日時 | 2025-06-12 21:19:14 |
合計ジャッジ時間 | 7,515 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 5 TLE * 1 -- * 27 |
ソースコード
import sys
from math import isqrt

MOD = 998244353


def multiply_poly(a, b, m):
    """Return the product of polynomials a and b, truncated to degree m.

    Polynomials are coefficient lists (index = degree) with entries reduced
    modulo MOD.  The result always has exactly m + 1 entries.
    """
    res = [0] * (m + 1)
    len_b = len(b)
    for i, ai in enumerate(a[:m + 1]):
        if ai == 0:
            # Zero coefficients contribute nothing; skipping them keeps
            # sparse / low-degree operands cheap.
            continue
        # Only b[j] with i + j <= m lands inside the truncation window.
        for j in range(min(len_b, m - i + 1)):
            res[i + j] = (res[i + j] + ai * b[j]) % MOD
    return res


def pow_poly(poly, e, m):
    """Return poly**e truncated to degree m, by binary exponentiation.

    The result always has m + 1 entries; for e == 0 it is the constant
    polynomial 1.
    """
    result = [1] + [0] * m
    current = list(poly)
    while e > 0:
        if e & 1:
            result = multiply_poly(result, current, m)
        e >>= 1
        if e:
            # Square only while exponent bits remain, so the last
            # iteration does not compute a square that is never used.
            current = multiply_poly(current, current, m)
    return result


def get_divisors(n):
    """Return all positive divisors of n in ascending order."""
    small = []
    large = []
    for i in range(1, isqrt(n) + 1):
        if n % i == 0:
            small.append(i)
            if i != n // i:
                large.append(n // i)
    large.reverse()
    return small + large


def euler_phi(n):
    """Return Euler's totient of n (n >= 1) via trial division."""
    result = n
    p = 2
    while p * p <= n:
        if n % p == 0:
            while n % p == 0:
                n //= p
            result -= result // p
        p += 1
    if n > 1:
        # Whatever is left is a single prime factor above sqrt(original n).
        result -= result // n
    return result


def count_arrangements(K, A):
    """Return (1/K) * sum over d | K of phi(d) * f(d), modulo MOD.

    For a divisor d, let m = K // d and cap_i = min(A[i] // d, m).  f(d) is
    m! times the x^m coefficient of prod_i sum_{k=0}^{cap_i} x^k / k!, i.e.
    the number of length-m sequences that use item i at most cap_i times.

    NOTE(review): this has the shape of a Burnside count of circular
    arrangements of K seats under rotation -- confirm against the problem
    statement.

    The factorial tables are sized from K, so no divisor is skipped
    regardless of how large K is.
    """
    fact = [1] * (K + 1)
    for i in range(1, K + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (K + 1)
    inv_fact[K] = pow(fact[K], MOD - 2, MOD)
    for i in range(K, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    total = 0
    for d in get_divisors(K):
        m = K // d
        caps = [min(a // d, m) for a in A]
        if sum(caps) < m:
            # Not enough items to fill m positions: f(d) = 0.
            continue

        # Items whose cap reaches m all share the same factor
        # sum_{k=0}^{m} x^k / k!, so raise it to a power in one go.
        full_count = caps.count(m)
        high = pow_poly(inv_fact[:m + 1], full_count, m)

        # Remaining items are folded in one by one.  A cap of 0 gives the
        # factor 1 and is skipped.
        low = [1] + [0] * m
        for c in caps:
            if 0 < c < m:
                low = multiply_poly(low, inv_fact[:c + 1], m)

        # Only the x^m coefficient of high * low is needed, so compute that
        # single convolution term instead of the whole product.
        coeff = sum(high[i] * low[m - i] for i in range(m + 1)) % MOD

        f_d = coeff * fact[m] % MOD
        total = (total + euler_phi(d) * f_d) % MOD

    return total * pow(K, MOD - 2, MOD) % MOD


def main():
    """Read "N K" and the list A from stdin and print the answer."""
    # N is read to consume the token; the length of A is taken from the
    # second input line itself.
    _n, K = map(int, sys.stdin.readline().split())
    A = list(map(int, sys.stdin.readline().split()))
    print(count_arrangements(K, A))


if __name__ == '__main__':
    main()