結果

問題 No.1753 Don't cheat.
ユーザー zkouzkou
提出日時 2021-06-19 18:41:50
言語 PyPy3 (7.3.15)
結果 AC
実行時間 1,361 ms / 3,000 ms
コード長 3,051 bytes
コンパイル時間 538 ms
コンパイル使用メモリ 86,832 KB
実行使用メモリ 94,568 KB
最終ジャッジ日時 2023-08-30 06:20:03
合計ジャッジ時間 31,230 ms
ジャッジサーバーID (参考情報) judge15 / judge12
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 68 ms 71,272 KB
testcase_01 AC 72 ms 71,480 KB
testcase_02 AC 67 ms 71,272 KB
testcase_03 AC 70 ms 71,276 KB
testcase_04 AC 68 ms 71,168 KB
testcase_05 AC 68 ms 71,544 KB
testcase_06 AC 70 ms 71,520 KB
testcase_07 AC 1,256 ms 93,752 KB
testcase_08 AC 1,308 ms 94,348 KB
testcase_09 AC 739 ms 86,928 KB
testcase_10 AC 1,295 ms 93,384 KB
testcase_11 AC 779 ms 86,404 KB
testcase_12 AC 1,226 ms 92,772 KB
testcase_13 AC 896 ms 88,752 KB
testcase_14 AC 1,122 ms 91,244 KB
testcase_15 AC 1,237 ms 92,652 KB
testcase_16 AC 935 ms 89,260 KB
testcase_17 AC 1,342 ms 94,352 KB
testcase_18 AC 1,204 ms 92,184 KB
testcase_19 AC 977 ms 89,512 KB
testcase_20 AC 1,254 ms 93,596 KB
testcase_21 AC 889 ms 88,568 KB
testcase_22 AC 1,150 ms 91,832 KB
testcase_23 AC 789 ms 87,208 KB
testcase_24 AC 1,283 ms 93,832 KB
testcase_25 AC 1,017 ms 89,868 KB
testcase_26 AC 1,361 ms 94,472 KB
testcase_27 AC 1,347 ms 94,528 KB
testcase_28 AC 1,338 ms 94,364 KB
testcase_29 AC 1,345 ms 94,372 KB
testcase_30 AC 1,338 ms 94,528 KB
testcase_31 AC 1,342 ms 94,568 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

# Prime modulus for all modular arithmetic below; inverses are taken with
# Fermat's little theorem, pow(x, MOD - 2, MOD).
MOD = 998244353
# Multiplicative inverse of 2 modulo MOD. Because MOD is odd, this is simply
# (MOD + 1) / 2. NOTE(review): not referenced elsewhere in this file.
half = (MOD + 1) // 2


def int2frac(n, m=None, N=10000, D=10000):
    """
    Recover a small fraction r / s from its residue n modulo m.

    Return (r, s) with r == s * n (mod m), |r| <= N, 0 < s <= D and
    gcd(r, s) == 1, found with the extended Euclidean algorithm (rational
    reconstruction). If the candidate it produces violates the bound on s
    or is not in lowest terms, return (0, 0).

    Parameters
    ----------
    n: the residue to represent as a fraction. Any integer is accepted;
        it is reduced modulo m first.
    m: the (positive) modulus; defaults to MOD.
    N: upper bound on |r|, the numerator.
    D: upper bound on s, the denominator.
    """
    import math

    if m is None:
        m = MOD
    # Reduce first. An unreduced (e.g. negative) n would otherwise skip the
    # loop entirely and be returned as (n, 1) even when |n| > N.
    n %= m
    # Invariant: v[0] == v[1] * n and w[0] == w[1] * n (mod m).
    v = (m, 0)
    w = (n, 1)
    while w[0] > N:
        q = v[0] // w[0]
        v, w = w, (v[0] - q * w[0], v[1] - q * w[1])
    # Make the denominator positive; the numerator then carries the sign.
    if w[1] < 0:
        w = (-w[0], -w[1])
    if w[1] <= D and math.gcd(w[0], w[1]) == 1:
        return w
    return (0, 0)


def fwht(a) -> None:
    """
    In-place Fast Walsh–Hadamard Transform of array a, modulo MOD.

    Afterwards a[k] == sum_j (-1) ** popcount(j & k) * a_old[j] (mod MOD).
    The transform is unnormalized (no division by len(a)).

    Parameters
    ----------
    a: a mutable sequence of integers whose length is 0 or a power of two.

    Raises
    ------
    ValueError: if len(a) is neither 0 nor a power of two. The check runs
        before any element is modified, so a is left untouched.

    Reference: https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform
    """
    n = len(a)
    # n & (n - 1) clears the lowest set bit; it is 0 only for 0 and powers of 2.
    if n & (n - 1):
        raise ValueError(
            "fwht: length must be 0 or a power of two, got {}".format(n))
    h = 1
    while h < n:
        for i in range(0, n, h * 2):
            for j in range(i, i + h):
                x = a[j]
                y = a[j + h]
                a[j] = (x + y) % MOD
                a[j + h] = (x - y) % MOD
        h *= 2

def ifwht(a) -> None:
    """
    In-place inverse Fast Walsh–Hadamard Transform of array a, modulo MOD.

    The Walsh–Hadamard matrix H satisfies H @ H == n * I with n = len(a), so
    the inverse is the forward transform (fwht) followed by multiplication
    by n^{-1} mod MOD. Reusing fwht keeps a single copy of the butterfly loop.

    len(a) must be 0 or a power of two; other lengths make fwht fail.
    Every element is reduced modulo MOD on return.

    Reference: https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform
    """
    fwht(a)
    # For power-of-two lengths this equals the final butterfly width.
    inv_n = pow(len(a), MOD - 2, MOD)
    for i, x in enumerate(a):
        a[i] = x * inv_n % MOD


N = int(input())
As = list(map(int, input().split()))

# Input constraints this solution assumes (not validated):
#   1 <= N <= 2 * 10 ** 3, 0 <= A <= 10 ** 5 for every A,
#   len(As) == N + 1, and As[0] != 0.

# Normalize the weights modulo MOD: As[i] <- As[i] / sum(As).
inv_total = pow(sum(As), MOD - 2, MOD)
As = [A * inv_total % MOD for A in As]
# Debug aid: print([int2frac(A) for A in As]) shows these as small fractions.

# Smallest power of two strictly greater than N, so indices 0..N all fit.
z = 1 << N.bit_length()

# parity[i] is 1 iff popcount(i) is odd, so the Walsh–Hadamard sign at
# (x, k) is (-1) ** parity[x & k]. Looking signs up on the fly keeps memory
# at O(z) instead of materializing (N + 1) x z sign and product matrices.
parity = [bin(i).count('1') & 1 for i in range(z)]

As_fwht = As + [0] * (z - N - 1)
fwht(As_fwht)

A0 = As[0]

# NOTE(review): a denominator that is 0 mod MOD would silently give
# pow(0, MOD - 2, MOD) == 0; valid input is assumed never to produce one.
q = [pow(1 - As_fwht[k] + A0, MOD - 2, MOD) * A0 % MOD for k in range(z)]
ifwht(q)

psum = [0] * z
for x in range(1, N + 1):
    Ax = As[x]
    for k in range(z):
        # Entry k of the transform of Ax placed at index x is +Ax or -Ax.
        signed_Ax = -Ax if parity[x & k] else Ax
        inv = pow(1 - As_fwht[k] + signed_Ax + A0, MOD - 2, MOD)
        psum[k] = (psum[k] + inv * Ax) % MOD
ifwht(psum)

answer = (q[0] + sum(s * t for s, t in zip(psum, q))) % MOD
print((1 - answer) % MOD)