結果

問題 No.2959 Dolls' Tea Party
ユーザー lam6er
提出日時 2025-03-26 16:00:05
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 5,811 bytes
コンパイル時間 252 ms
コンパイル使用メモリ 82,576 KB
実行使用メモリ 146,612 KB
最終ジャッジ日時 2025-03-26 16:01:14
合計ジャッジ時間 6,806 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 5 TLE * 1 -- * 27
権限があれば一括ダウンロードができます

ソースコード

diff #
プレゼンテーションモードにする

import sys
MOD = 998244353
def main():
input = sys.stdin.read().split()
ptr = 0
N, K = int(input[ptr]), int(input[ptr+1])
ptr +=2
A = list(map(int, input[ptr:ptr+N]))
# Precompute factorial and inverse factorial up to s_max + K_max
s_max = K
fact = [1]*(s_max + 1)
for i in range(1, s_max+1):
fact[i] = fact[i-1] * i % MOD
inv_fact = [1]*(s_max +1)
inv_fact[s_max] = pow(fact[s_max], MOD-2, MOD)
for i in range(s_max-1, -1, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
# Precompute divisors of K
def get_divisors(n):
divisors = set()
for i in range(1, int(n**0.5)+1):
if n %i ==0:
divisors.add(i)
divisors.add(n//i)
return sorted(divisors)
divisors = get_divisors(K)
divisors.sort()
# Precompute phi for each divisor
phi = {}
for d in divisors:
m = d
res = m
for i in range(2, int(m**0.5)+1):
if i*i >m:
break
if m %i ==0:
res = res //i * (i-1)
while m %i ==0:
m = m//i
if m>1:
res = res //m * (m-1)
phi[d] = res
# Precompute frequency array for A_i < K
K_floor = K
freq = [0]*(K_floor)
C = 0
for a in A:
if a >= K:
C +=1
else:
freq[a] +=1
# Precompute for each divisor d
total =0
for d in divisors:
s = K // d
if s ==0:
continue
# Compute m_i for all i where A_i <K
sum_floor_high =0
sum_floor_low =0
for x in range(K_floor):
sum_floor_low += freq[x] * (x //d)
# sum_floor_high is sum over A_i >=K of floor(A_i/d)
sum_floor_high =0
for a in A:
if a >=K:
sum_floor_high += (a //d)
total_B = sum_floor_low + sum_floor_high
if total_B < s:
continue
# Now compute the number of sequences of length s, with each type count <= m_i
# m_i is floor(A_i/d)
# Split into two parts: m_i >=s and m_i <s
# Part 1: m_i >=s (i.e., A_i >=d*s = K)
# C is the number of such types
# Part 2: m_i <s (A_i < K)
# Compute their generating functions
# Part 1: (sum_{k=0}^s x^k /k! )^C
# Part 2: product_{i: m_i <s} (sum_{k=0}^m_i x^k/k! )
dp = [0]*(s+1)
dp[0] =1
# Handle part 1: (sum_{k=0}^s x^k/k! )^C
# We need to raise the polynomial to the Cth power
# First, compute the polynomial (sum x^k/k! for k=0 to s)
poly_part1 = [0]*(s+1)
for k in range(s+1):
poly_part1[k] = inv_fact[k]
# Compute poly_part1^C using exponentiation by squaring
def poly_pow(p, exponent, max_degree):
result = [0]*(max_degree+1)
result[0] =1
current = p.copy()
while exponent >0:
if exponent %2 ==1:
# Multiply result by current
new_result = [0]*(max_degree+1)
for i in range(max_degree+1):
if result[i] ==0:
continue
for j in range(max_degree+1 -i):
new_result[i+j] = (new_result[i+j] + result[i] * current[j]) % MOD
result = new_result
# Square current
new_current = [0]*(max_degree*2 +1)
for i in range(max_degree+1):
if current[i] ==0:
continue
for j in range(max_degree+1 -i):
new_current[i+j] = (new_current[i+j] + current[i] * current[j]) % MOD
current = new_current[:max_degree+1]
exponent = exponent //2
return result
part1 = poly_pow(poly_part1, C, s)
# Now handle part2: types with A_i < K, m_i = floor(A_i/d)
part2 = [0]*(s+1)
part2[0] =1
# Precompute m_i for all i where A_i <K
# Group by m_i
groups = {}
for x in range(K_floor):
cnt = freq[x]
if cnt ==0:
continue
m_i = x //d
if m_i >=s:
continue # already handled in part1
if m_i not in groups:
groups[m_i] =0
groups[m_i] += cnt
# Process each group in part2
for t in groups:
cnt = groups[t]
if cnt ==0:
continue
# The polynomial for this group is (sum_{k=0}^t x^k/k! )^cnt
poly = [inv_fact[k] if k <=t else 0 for k in range(s+1)]
poly = poly_pow(poly, cnt, s)
# Multiply into part2
new_part2 = [0]*(s+1)
for i in range(s+1):
if part2[i] ==0:
continue
for j in range(s+1 -i):
new_part2[i+j] = (new_part2[i+j] + part2[i] * poly[j]) % MOD
part2 = new_part2
# Combine part1 and part2
combined = [0]*(s+1)
for i in range(s+1):
if part1[i] ==0:
continue
for j in range(s+1 -i):
combined[i+j] = (combined[i+j] + part1[i] * part2[j]) % MOD
# Get the coefficient of x^s
coeff = combined[s]
f_d = coeff * fact[s] % MOD
# Add to total
total = (total + phi[d] * f_d) % MOD
# Compute answer
answer = total * pow(K, MOD-2, MOD) % MOD
print(answer)
if __name__ == '__main__':
main()
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
0