Result
Item | Value |
---|---|
Problem | No.2613 Sum of Combination |
User | Mitarushi |
Submitted at | 2024-01-19 17:12:22 |
Language | PyPy3 (7.3.15) |
Result | AC |
Execution time | 2,451 ms / 4,500 ms |
Code length | 2,932 bytes |
Compile time | 220 ms |
Compile memory | 82,176 KB |
Execution memory | 100,224 KB |
Last judged at | 2024-09-28 03:48:50 |
Total judge time | 48,445 ms |
Judge server ID (for reference) | judge3 / judge1 |
Test cases
Input | Result | Execution time | Memory |
---|---|---|---|
testcase_00 | AC | 39 ms | 52,352 KB |
testcase_01 | AC | 45 ms | 52,352 KB |
testcase_02 | AC | 119 ms | 76,672 KB |
testcase_03 | AC | 41 ms | 52,608 KB |
testcase_04 | AC | 42 ms | 52,096 KB |
testcase_05 | AC | 41 ms | 52,736 KB |
testcase_06 | AC | 40 ms | 52,736 KB |
testcase_07 | AC | 40 ms | 52,480 KB |
testcase_08 | AC | 77 ms | 72,064 KB |
testcase_09 | AC | 62 ms | 65,408 KB |
testcase_10 | AC | 40 ms | 52,736 KB |
testcase_11 | AC | 69 ms | 68,608 KB |
testcase_12 | AC | 72 ms | 69,248 KB |
testcase_13 | AC | 136 ms | 76,800 KB |
testcase_14 | AC | 131 ms | 76,800 KB |
testcase_15 | AC | 113 ms | 76,300 KB |
testcase_16 | AC | 134 ms | 77,056 KB |
testcase_17 | AC | 133 ms | 76,672 KB |
testcase_18 | AC | 134 ms | 76,928 KB |
testcase_19 | AC | 134 ms | 76,928 KB |
testcase_20 | AC | 85 ms | 76,416 KB |
testcase_21 | AC | 76 ms | 71,680 KB |
testcase_22 | AC | 190 ms | 77,696 KB |
testcase_23 | AC | 2,015 ms | 99,328 KB |
testcase_24 | AC | 1,911 ms | 99,200 KB |
testcase_25 | AC | 1,920 ms | 98,688 KB |
testcase_26 | AC | 1,946 ms | 99,840 KB |
testcase_27 | AC | 927 ms | 88,064 KB |
testcase_28 | AC | 1,936 ms | 99,712 KB |
testcase_29 | AC | 1,924 ms | 99,328 KB |
testcase_30 | AC | 2,451 ms | 100,224 KB |
testcase_31 | AC | 2,043 ms | 99,328 KB |
testcase_32 | AC | 1,903 ms | 99,200 KB |
testcase_33 | AC | 2,018 ms | 99,712 KB |
testcase_34 | AC | 1,863 ms | 99,584 KB |
testcase_35 | AC | 1,910 ms | 99,584 KB |
testcase_36 | AC | 1,948 ms | 100,096 KB |
testcase_37 | AC | 1,955 ms | 99,584 KB |
testcase_38 | AC | 1,962 ms | 99,584 KB |
testcase_39 | AC | 1,916 ms | 99,840 KB |
testcase_40 | AC | 1,920 ms | 99,584 KB |
testcase_41 | AC | 1,910 ms | 100,224 KB |
testcase_42 | AC | 1,922 ms | 100,096 KB |
testcase_43 | AC | 1,943 ms | 99,712 KB |
testcase_44 | AC | 1,827 ms | 99,840 KB |
testcase_45 | AC | 40 ms | 52,992 KB |
testcase_46 | AC | 41 ms | 52,736 KB |
testcase_47 | AC | 41 ms | 53,120 KB |
testcase_48 | AC | 47 ms | 60,800 KB |
testcase_49 | AC | 55 ms | 62,848 KB |
testcase_50 | AC | 1,331 ms | 99,456 KB |
testcase_51 | AC | 1,377 ms | 99,840 KB |
Source code
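```python
MOD = 998244353  # NTT-friendly prime: 119 * 2^23 + 1


def find_primitive_root(n):
    """Return the smallest primitive root modulo the prime n."""
    # Distinct prime factors of n - 1 by trial division.
    phi = n - 1
    factors = []
    for i in range(2, n):
        if i * i > phi:
            break
        if phi % i == 0:
            factors.append(i)
            while phi % i == 0:
                phi //= i
    if phi > 1:
        factors.append(phi)
    # res is a generator iff res^((n - 1) / f) != 1 for every prime factor f.
    for res in range(1, n):
        ok = True
        for factor in factors:
            if pow(res, (n - 1) // factor, n) == 1:
                ok = False
                break
        if ok:
            return res
    return -1


primitive_root = find_primitive_root(MOD)


def ntt(a):
    """In-place forward NTT (decimation in frequency).

    No bit-reversal permutation is applied, so the output is in
    bit-reversed order; intt() expects exactly that order.
    """
    n = len(a)
    m = n
    while m > 1:
        mh = m >> 1
        wm = pow(primitive_root, (MOD - 1) // m, MOD)  # primitive m-th root of unity
        w = 1
        for i in range(mh):
            for j in range(i, n, m):
                k = j + mh
                a0 = a[j]
                a1 = a[k]
                a[j] = a0 + a1
                if a[j] >= MOD:
                    a[j] -= MOD
                a[k] = (a0 - a1 + MOD) * w % MOD
            w = w * wm % MOD
        m = mh


def intt(a):
    """In-place inverse NTT (decimation in time) from bit-reversed input, scaled by 1/n."""
    n = len(a)
    m = 2
    while m <= n:
        mh = m >> 1
        wm = pow(primitive_root, MOD - 1 - (MOD - 1) // m, MOD)  # inverse root of unity
        w = 1
        for i in range(mh):
            for j in range(i, n, m):
                k = j + mh
                a0 = a[j]
                a1 = a[k] * w % MOD
                a[j] = a0 + a1
                if a[j] >= MOD:
                    a[j] -= MOD
                a[k] = a0 - a1
                if a[k] < 0:
                    a[k] += MOD
            w = w * wm % MOD
        m <<= 1
    inv = pow(n, MOD - 2, MOD)
    for i in range(n):
        a[i] = a[i] * inv % MOD


def solve():
    def comb(a, b):
        # Discrete log of C(a, b) mod p: log(a!) - log(b!) - log((a - b)!) mod (p - 1).
        return (index_table_sum[a] - index_table_sum[b] - index_table_sum[a - b] + 2 * (p - 1)) % (p - 1)

    n, p = map(int, input().split())
    q = find_primitive_root(p)

    # index_table[x] = discrete log of x to base q (mod p).
    index_table = [0] * p
    k = 1
    for i in range(p - 1):
        index_table[k] = i
        k = k * q % p

    # index_table_sum[i] = log(i!) mod (p - 1).
    index_table_sum = [0] * p
    for i in range(1, p):
        index_table_sum[i] = index_table_sum[i - 1] + index_table[i]
        if index_table_sum[i] >= p - 1:
            index_table_sum[i] -= p - 1

    # Power of two >= 2(p - 1), so a product of two length-(p - 1) arrays does not wrap.
    size = 1
    while size < (p - 1) * 2:
        size *= 2

    # count[e] = number of k (mod MOD) whose Lucas product is q^e (mod p).
    count = [0] * size
    count[0] = 1
    while n > 0:
        # Lucas's theorem: handle one base-p digit m of n at a time.
        m = n % p
        n //= p
        a = [0] * size
        for i in range(m + 1):
            a[comb(m, i)] += 1
        ntt(count)
        ntt(a)
        for i in range(size):
            count[i] = count[i] * a[i] % MOD
        intt(count)
        # Fold exponents >= p - 1 back down (cyclic convolution mod p - 1).
        for i in range(size - 1, p - 2, -1):
            count[i - p + 1] += count[i]
            if count[i - p + 1] >= MOD:
                count[i - p + 1] -= MOD
            count[i] = 0

    # Each exponent e stands for the residue q^e mod p.
    ans = 0
    k = 1
    for i in range(p - 1):
        ans = (ans + count[i] * k) % MOD
        k = k * q % p
    print(ans)


solve()
```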
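Reading the code, the submission appears to compute the sum over k = 0..n of (C(n, k) mod p), reduced mod 998244353. It splits n into base-p digits (Lucas's theorem), replaces each residue C(n_i, k_i) mod p by its discrete logarithm to a primitive root so that products become sums mod p - 1, and accumulates the exponent distribution with an NTT-based cyclic convolution. The sketch below is a hypothetical cross-check that assumes this reading of the task (the problem statement is not reproduced on this page). `direct_sum` evaluates the sum directly with `math.comb`; `log_convolution_sum` follows the same discrete-log idea but uses a naive O(p^2) convolution instead of the NTT; `primitive_root` is a simplified stand-in for `find_primitive_root`. All three names are illustrative and not part of the submission, and the code is only meant for small n and p.

```python
from math import comb

MOD = 998244353


def direct_sum(n, p):
    """Hypothetical reference: sum of (C(n, k) mod p) for k = 0..n, reduced mod MOD."""
    return sum(comb(n, k) % p for k in range(n + 1)) % MOD


def primitive_root(p):
    """Hypothetical helper: smallest generator of the multiplicative group mod the prime p."""
    phi = p - 1
    # Prime factors of p - 1 by naive trial division (fine for small p).
    factors = [f for f in range(2, phi + 1)
               if phi % f == 0 and all(f % d for d in range(2, f))]
    for g in range(1, p):
        if all(pow(g, phi // f, p) != 1 for f in factors):
            return g
    raise ValueError("p must be prime")


def log_convolution_sum(n, p):
    """Hypothetical helper: the same sum via Lucas's theorem and discrete logs,
    with a naive O(p^2) cyclic convolution standing in for the submission's NTT."""
    g = primitive_root(p)
    order = p - 1
    dlog = [0] * p  # dlog[x] = e with g^e == x (mod p), for 1 <= x < p
    x = 1
    for e in range(order):
        dlog[x] = e
        x = x * g % p
    count = [0] * order  # count[e] = how many k give a Lucas product == g^e (mod p)
    count[0] = 1         # before any digit, the empty product is 1 = g^0
    while n > 0:
        d = n % p        # next base-p digit of n, least significant first
        n //= p
        # C(d, b) mod p is never 0 because d < p and p is prime.
        digit = [0] * order
        for b in range(d + 1):
            digit[dlog[comb(d, b) % p]] += 1
        # Multiplying residues adds exponents, so convolve cyclically mod p - 1.
        nxt = [0] * order
        for i, ci in enumerate(count):
            for j, dj in enumerate(digit):
                idx = (i + j) % order
                nxt[idx] = (nxt[idx] + ci * dj) % MOD
        count = nxt
    # Each exponent e stands for the residue g^e mod p.
    return sum(c * pow(g, e, p) for e, c in enumerate(count)) % MOD


print(direct_sum(8, 3), log_convolution_sum(8, 3))  # prints: 13 13
```

For n = 8 and p = 3 the base-3 digits of n are (2, 2), and the nine Lucas products for k = 0..8 are 1, 2, 1, 2, 1, 2, 1, 2, 1, which sum to 13. Because each product is reduced mod p before summing, the result is not simply the product of the per-digit sums (4 * 4 = 16); this is why the counts have to be tracked per residue (here, per exponent) rather than multiplied digit by digit.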