Result
Problem | No.1753 Don't cheat. |
User | zkou |
Submitted | 2021-06-19 18:41:50 |
Language | PyPy3 (7.3.15) |
Result | AC |
Execution time | 1,239 ms / 3,000 ms |
Code length | 3,051 bytes |
Compile time | 256 ms |
Compile memory | 82,304 KB |
Execution memory | 93,424 KB |
Last judged | 2024-06-10 06:49:50 |
Total judge time | 27,544 ms |
Judge server ID (for reference) | judge5 / judge4 |
Test cases
Input | Result | Execution time | Execution memory |
---|---|---|---|
testcase_00 | AC | 37 ms | 52,864 KB |
testcase_01 | AC | 37 ms | 52,608 KB |
testcase_02 | AC | 36 ms | 53,248 KB |
testcase_03 | AC | 36 ms | 52,736 KB |
testcase_04 | AC | 36 ms | 52,736 KB |
testcase_05 | AC | 37 ms | 52,992 KB |
testcase_06 | AC | 38 ms | 52,864 KB |
testcase_07 | AC | 1,149 ms | 91,588 KB |
testcase_08 | AC | 1,187 ms | 92,704 KB |
testcase_09 | AC | 683 ms | 84,864 KB |
testcase_10 | AC | 1,179 ms | 92,292 KB |
testcase_11 | AC | 702 ms | 85,780 KB |
testcase_12 | AC | 1,111 ms | 91,532 KB |
testcase_13 | AC | 805 ms | 87,264 KB |
testcase_14 | AC | 1,014 ms | 90,364 KB |
testcase_15 | AC | 1,111 ms | 91,652 KB |
testcase_16 | AC | 837 ms | 88,064 KB |
testcase_17 | AC | 1,208 ms | 93,064 KB |
testcase_18 | AC | 1,071 ms | 91,164 KB |
testcase_19 | AC | 872 ms | 88,636 KB |
testcase_20 | AC | 1,127 ms | 92,324 KB |
testcase_21 | AC | 789 ms | 87,404 KB |
testcase_22 | AC | 1,033 ms | 90,376 KB |
testcase_23 | AC | 697 ms | 85,912 KB |
testcase_24 | AC | 1,154 ms | 92,680 KB |
testcase_25 | AC | 919 ms | 89,472 KB |
testcase_26 | AC | 1,231 ms | 92,940 KB |
testcase_27 | AC | 1,239 ms | 93,424 KB |
testcase_28 | AC | 1,193 ms | 93,184 KB |
testcase_29 | AC | 1,204 ms | 92,972 KB |
testcase_30 | AC | 1,194 ms | 93,312 KB |
testcase_31 | AC | 1,199 ms | 92,952 KB |
Source code
MOD = 998244353
half = pow(2, MOD - 2, MOD)  # modular inverse of 2 (not used below)


def int2frac(n, m=None, N=10000, D=10000):
    """
    Return (r, s) s.t. r = s * n (mod m) if such a pair exists.
    Otherwise, return (0, 0).

    Parameters
    ----------
    n: an integer that will be represented as a fraction.
    m: A modulus used to represent n.
    N: An upper bound of r, which is the numerator of n.
    D: An upper bound of s, which is the denominator of n.
    """
    def gcd(a, b):
        while b:
            a, b = b, a % b
        return a

    if m is None:
        m = MOD
    v = (m, 0)
    w = (n, 1)
    while w[0] > N:
        q = v[0] // w[0]
        v, w = w, (v[0] - q * w[0], v[1] - q * w[1])
    if w[1] < 0:
        w = (-w[0], -w[1])
    if w[1] <= D and gcd(w[0], w[1]) == 1:
        return w
    else:
        return (0, 0)


def fwht(a) -> None:
    """
    In-place Fast Walsh–Hadamard Transform of array a.
    Reference: https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform
    """
    h = 1
    while h < len(a):
        for i in range(0, len(a), h * 2):
            for j in range(i, i + h):
                x = a[j]
                y = a[j + h]
                a[j] = (x + y) % MOD
                a[j + h] = (x - y) % MOD
        h *= 2


def ifwht(a) -> None:
    """
    In-place Inverse Fast Walsh–Hadamard Transform of array a.
    Reference: https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform
    """
    h = 1
    while h < len(a):
        for i in range(0, len(a), h * 2):
            for j in range(i, i + h):
                x = a[j]
                y = a[j + h]
                a[j] = (x + y) % MOD
                a[j + h] = (x - y) % MOD
        h *= 2
    # h == len(a) here, so this divides every entry by the array length.
    inv_h = pow(h, MOD - 2, MOD)
    for i in range(len(a)):
        a[i] = a[i] * inv_h % MOD


N = int(input())
As = list(map(int, input().split()))
# assert 1 <= N <= 2 * 10 ** 3
# assert all(0 <= A <= 10 ** 5 for A in As)
# assert N + 1 == len(As)
# assert As[0]
sumAs = sum(As)
invsumAs = pow(sum(As), MOD - 2, MOD)
# Turn the weights into probabilities A_i / sum(A), modulo MOD.
for i in range(N + 1):
    As[i] *= invsumAs
    As[i] %= MOD
# print([int2frac(A) for A in As])
# z: smallest power of two greater than N, so indices 0..N fit.
z = 1 << N.bit_length()
# Hadamard[i][j] = (-1)^popcount(i & j), with -1 stored as MOD - 1.
Hadamard = [[MOD - 1 if bin(i & j).count('1') % 2 == 1 else 1 for j in range(z)] for i in range(N + 1)]
# fwhts[x] = A_x * Hadamard[x], the transform of the vector holding A_x at index x.
fwhts = [[A * h % MOD for h in row] for A, row in zip(As, Hadamard)]
e0 = Hadamard[0]  # all ones
As_fwht = As + [0] * (z - N - 1)
fwht(As_fwht)
q = [pow(1 - As_fwht[k] + fwhts[0][k], MOD - 2, MOD) * As[0] % MOD * e0[k] % MOD for k in range(z)]
# assert all((1 - As_fwht[k] + fwhts[0][k]) % MOD != 0 for k in range(z))
ifwht(q)
psum = [0] * z
for x in range(1, N + 1):
    p = [pow(1 - As_fwht[k] + fwhts[x][k] + fwhts[0][k], MOD - 2, MOD) * As[x] % MOD * e0[k] % MOD for k in range(z)]
    # assert all((1 - As_fwht[k] + fwhts[x][k] + fwhts[0][k]) % MOD != 0 for k in range(z))
    for i, e in enumerate(p):
        psum[i] += e
        psum[i] %= MOD
ifwht(psum)
answer = q[0]
for xor in range(z):
    answer += psum[xor] * q[xor] % MOD
    answer %= MOD
print((1 - answer) % MOD)
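The `fwht`/`ifwht` pair above is the usual in-place butterfly for the Walsh–Hadamard transform over XOR. As an illustration that is not part of the submission, the sketch below checks two properties the code appears to rely on. First, transforming, multiplying pointwise, and inverse-transforming gives XOR convolution modulo 998244353. Second, row x of the explicit `Hadamard` matrix equals the transform of the unit vector at index x, which is presumably why `fwhts[x]` can be built as `A_x * Hadamard[x]` without calling `fwht`. The helper names `xor_convolve` and `xor_convolve_naive` are hypothetical, and the brute-force comparison only covers small random inputs.

```python
import random

MOD = 998244353


def fwht(a):
    """In-place Walsh-Hadamard transform mod MOD (same butterfly as the submission)."""
    h = 1
    while h < len(a):
        for i in range(0, len(a), h * 2):
            for j in range(i, i + h):
                x, y = a[j], a[j + h]
                a[j] = (x + y) % MOD
                a[j + h] = (x - y) % MOD
        h *= 2


def ifwht(a):
    """Inverse transform: apply the same butterfly, then divide by len(a)."""
    fwht(a)
    inv_n = pow(len(a), MOD - 2, MOD)  # Fermat inverse; MOD is prime
    for i in range(len(a)):
        a[i] = a[i] * inv_n % MOD


def xor_convolve(f, g):
    """Hypothetical helper: c[k] = sum of f[i] * g[j] over i ^ j == k (mod MOD).

    len(f) == len(g) must be a power of two.
    """
    F, G = f[:], g[:]
    fwht(F)
    fwht(G)
    c = [x * y % MOD for x, y in zip(F, G)]  # pointwise product in the transformed domain
    ifwht(c)
    return c


def xor_convolve_naive(f, g):
    """Hypothetical O(n^2) reference used only for checking."""
    n = len(f)
    c = [0] * n
    for i in range(n):
        for j in range(n):
            c[i ^ j] = (c[i ^ j] + f[i] * g[j]) % MOD
    return c


random.seed(0)
n = 8
f = [random.randrange(MOD) for _ in range(n)]
g = [random.randrange(MOD) for _ in range(n)]
assert xor_convolve(f, g) == xor_convolve_naive(f, g)

# Row x of the submission's explicit matrix, (-1)^popcount(x & k) with -1 stored
# as MOD - 1, equals the transform of the unit vector e_x.
for x in range(n):
    e = [0] * n
    e[x] = 1
    fwht(e)
    assert e == [MOD - 1 if bin(x & k).count("1") % 2 == 1 else 1 for k in range(n)]

print(xor_convolve([1, 2, 0, 0], [0, 1, 0, 0]))  # prints [2, 1, 0, 0]
```

Each transform costs O(z log z) for an array of length z, so one XOR convolution costs O(z log z) instead of the O(z^2) double loop in the naive version.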