結果
| 問題 | No.2613 Sum of Combination |
|---|---|
| コンテスト | |
| ユーザー | |
| 提出日時 | 2024-01-19 17:12:22 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 2,451 ms / 4,500 ms |
| コード長 | 2,932 bytes |
| 記録 | |
| コンパイル時間 | 220 ms |
| コンパイル使用メモリ | 82,176 KB |
| 実行使用メモリ | 100,224 KB |
| 最終ジャッジ日時 | 2024-09-28 03:48:50 |
| 合計ジャッジ時間 | 48,445 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge1 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 49 |
ソースコード
# Prime modulus used both by the NTT and for the final printed answer.
# 998244353 = 119 * 2**23 + 1, so power-of-two transform lengths up to
# 2**23 have the roots of unity that ntt()/intt() need.
MOD = 998244353
def find_primitive_root(n):
    """Return the smallest primitive root modulo ``n``, or -1 if none is found.

    Meant for prime ``n``: the multiplicative group is taken to have order
    ``n - 1``, and a candidate ``g`` is accepted when
    ``pow(g, (n - 1) // f, n) != 1`` for every distinct prime factor ``f``
    of ``n - 1``.  Candidates are tried upward from 1, so ``n == 2`` gives 1.
    For composite ``n`` the result is not meaningful.
    """
    group_order = n - 1

    # Trial-divide n - 1 to collect its distinct prime factors.
    rest = group_order
    prime_factors = []
    d = 2
    while d < n and d * d <= rest:
        if rest % d == 0:
            prime_factors.append(d)
            while rest % d == 0:
                rest //= d
        d += 1
    if rest > 1:
        # Whatever survives trial division is itself a prime factor.
        prime_factors.append(rest)

    exponents = [group_order // f for f in prime_factors]
    return next(
        (
            g
            for g in range(1, n)
            if all(pow(g, e, n) != 1 for e in exponents)
        ),
        -1,
    )
# Generator of the multiplicative group modulo MOD (evaluates to 3 for this
# prime); ntt() and intt() raise it to (MOD - 1) // m to get m-th roots of unity.
primitive_root = find_primitive_root(MOD)
def ntt(a):
    """Forward number-theoretic transform of ``a`` modulo MOD, in place.

    Uses decimation-in-frequency butterflies with no bit-reversal pass, so
    the result is left in bit-reversed order; intt() consumes that order
    directly.  Assumes ``len(a)`` is a power of two and entries lie in
    ``[0, MOD)`` (the caller pads and reduces accordingly).
    """
    size = len(a)
    span = size
    while span > 1:
        half = span >> 1
        # Principal span-th root of unity.
        step_root = pow(primitive_root, (MOD - 1) // span, MOD)
        twiddle = 1
        for offset in range(half):
            for lo in range(offset, size, span):
                hi = lo + half
                x, y = a[lo], a[hi]
                total = x + y
                a[lo] = total - MOD if total >= MOD else total
                a[hi] = (x - y) * twiddle % MOD
            twiddle = twiddle * step_root % MOD
        span = half
def intt(a):
    """Inverse of ntt(): transform ``a`` back in place, modulo MOD.

    Expects its input in the bit-reversed order ntt() produces.  It runs
    decimation-in-time butterflies with inverse roots of unity and finally
    multiplies every entry by ``len(a)**-1 mod MOD``.  Assumes ``len(a)`` is
    a power of two.
    """
    size = len(a)
    span = 2
    while span <= size:
        half = span >> 1
        # Inverse of the principal span-th root of unity.
        step_root = pow(primitive_root, (MOD - 1) - (MOD - 1) // span, MOD)
        twiddle = 1
        for offset in range(half):
            for lo in range(offset, size, span):
                hi = lo + half
                x = a[lo]
                y = a[hi] * twiddle % MOD
                total = x + y
                a[lo] = total - MOD if total >= MOD else total
                diff = x - y
                a[hi] = diff + MOD if diff < 0 else diff
            twiddle = twiddle * step_root % MOD
        span <<= 1
    # Undo the factor of `size` picked up by the round trip.
    scale = pow(size, MOD - 2, MOD)
    a[:] = [v * scale % MOD for v in a]
def _log_factorial_table(p, g):
    """Return ``t`` with ``g**t[i] == i! (mod p)`` for ``0 <= i < p``.

    Exponents are reduced modulo ``p - 1``.  ``g`` must be a primitive root
    modulo the prime ``p`` so that every non-zero residue has a discrete log.
    """
    order = p - 1
    # dlog[x] = e such that g**e == x (mod p), for x in 1..p-1.
    dlog = [0] * p
    power = 1
    for e in range(order):
        dlog[power] = e
        power = power * g % p
    table = [0] * p
    for i in range(1, p):
        table[i] = (table[i - 1] + dlog[i]) % order
    return table


def _cyclic_multiply(count, term, period):
    """Replace ``count`` in place by ``count * term`` modulo ``x**period - 1``.

    Both lists share a power-of-two length that the caller chooses large
    enough for the linear product not to wrap inside the NTT.  ``term`` is
    clobbered.  Afterwards only ``count[:period]`` can be non-zero.
    """
    ntt(count)
    ntt(term)
    for i, t in enumerate(term):
        count[i] = count[i] * t % MOD
    intt(count)
    # Fold exponents >= period back onto [0, period).  Walking downward means
    # anything that lands at or above `period` is folded again later.
    for i in range(len(count) - 1, period - 1, -1):
        count[i - period] = (count[i - period] + count[i]) % MOD
        count[i] = 0


def _sum_of_binomials_mod_p(n, p):
    """Return ``sum(C(n, k) % p for k in range(n + 1))`` reduced modulo MOD.

    By Lucas' theorem, C(n, k) mod p is the product of C(n_i, k_i) mod p
    over the base-p digits, and it is zero whenever some k_i > n_i.
    Non-zero residues are tracked by their discrete log base a primitive
    root g, which turns the digit-wise product into a sum of exponents
    modulo p - 1.  ``count[e]`` ends up holding (mod MOD) the number of k
    with C(n, k) == g**e (mod p).

    Raises:
        ValueError: if ``p < 2`` (``p == 1`` would otherwise loop forever).
    """
    if p < 2:
        raise ValueError("p must be a prime >= 2")
    order = p - 1
    g = find_primitive_root(p)
    log_fact = _log_factorial_table(p, g)

    # Smallest power of two that holds the linear product of two
    # length-`order` sequences without wrap-around.
    size = 1
    while size < 2 * order:
        size *= 2

    count = [0] * size
    count[0] = 1  # no digits processed yet: the empty product is 1 == g**0
    while n > 0:
        n, digit = divmod(n, p)
        # Histogram of log C(digit, i) over the digits i of k that keep the
        # binomial non-zero (i <= digit).
        term = [0] * size
        for i in range(digit + 1):
            term[(log_fact[digit] - log_fact[i] - log_fact[digit - i]) % order] += 1
        _cyclic_multiply(count, term, order)

    total = 0
    value = 1  # g**e mod p for the current exponent e
    for c in count[:order]:
        total = (total + c * value) % MOD
        value = value * g % p
    return total


def solve():
    """Read ``N P`` from one line of stdin and print the answer.

    The printed value is ``sum(C(N, k) mod P for k in 0..N)`` modulo MOD.
    ``P`` is assumed prime; the algorithm relies on Lucas' theorem and a
    primitive root modulo ``P``.
    """
    n, p = map(int, input().split())
    print(_sum_of_binomials_mod_p(n, p))
# Script entry point: solve() reads one line of input and prints the answer.
solve()