結果

問題 No.3117 Reversible Tile
ユーザー Mistletoe
提出日時 2025-04-19 19:58:31
言語 Python3
(3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果
AC  
実行時間 962 ms / 3,000 ms
コード長 2,004 bytes
コンパイル時間 401 ms
コンパイル使用メモリ 12,544 KB
実行使用メモリ 11,648 KB
最終ジャッジ日時 2025-04-19 19:58:42
合計ジャッジ時間 9,808 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 24
権限があれば一括ダウンロードができます

ソースコード

diff #

#!/usr/bin/env python3
import sys
input = sys.stdin.readline
MOD = 998244353

def main():
    N, M = map(int, input().split())
    A = list(map(int, input().split()))
    # Build the difference‐parity vector B of length N+1
    B = [0]*(N+1)
    B[0] = A[0]
    for i in range(1, N):
        B[i] = A[i] ^ A[i-1]
    B[N] = A[N-1]
    k = sum(B)             # weight of B

    # Precompute factorials and inverses up to N+1
    fac = [1]*(N+2)
    for i in range(1, N+2):
        fac[i] = fac[i-1]*i % MOD
    invfac = [1]*(N+2)
    invfac[N+1] = pow(fac[N+1], MOD-2, MOD)
    for i in range(N, -1, -1):
        invfac[i] = invfac[i+1]*(i+1) % MOD

    # Build C1[j] = (-1)^j * C(k, j)  for j=0..k
    C1 = [0]*(k+1)
    for j in range(k+1):
        # C(k,j)
        c = fac[k]*invfac[j] % MOD * invfac[k-j] % MOD
        if j & 1:
            c = (MOD - c)
        C1[j] = c

    # Build C2[s] = C(N+1-k, s)  for s=0..N+1-k
    n2 = (N+1) - k
    C2 = [0]*(n2+1)
    for s in range(n2+1):
        C2[s] = fac[n2]*invfac[s] % MOD * invfac[n2-s] % MOD

    # Convolution P = C1 * C2  =>  P[a] for a=0..N+1
    # We do it in O(k * n2)
    P = [0]*(N+2)
    for j in range(k+1):
        v1 = C1[j]
        # unroll inner loop partially? fine as is
        for s in range(n2+1):
            P[j+s] += v1 * C2[s]
    # reduce mod
    for a in range(N+2):
        P[a] %= MOD

    # Precompute T and the S(a) = sum_{i<j}(-1)^{X_i+X_j} for |X|=a
    T = N*(N+1)//2 % MOD
    S = [0]*(N+2)
    nn = N+1
    for a in range(N+2):
        # b = nn - a
        # S(a) = [#equal pairs] - [#unequal pairs]
        #      = (C(a,2)+C(b,2)) - a*b = T - 2ab  mod
        ab2 = (2 * a * (nn - a)) % MOD
        S[a] = (T - ab2) % MOD

    # Sum up P[a] * S[a]^M
    total = 0
    for a in range(N+2):
        total = (total + P[a] * pow(S[a], M, MOD)) % MOD

    # Multiply by 2^{-(N+1)}  mod
    inv2 = (MOD + 1)//2
    total = total * pow(inv2, N+1, MOD) % MOD

    print(total)

if __name__ == "__main__":
    main()
0