結果
| 問題 |
No.3117 Reversible Tile
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2025-04-19 19:58:31 |
| 言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
| 結果 |
AC
|
| 実行時間 | 962 ms / 3,000 ms |
| コード長 | 2,004 bytes |
| コンパイル時間 | 401 ms |
| コンパイル使用メモリ | 12,544 KB |
| 実行使用メモリ | 11,648 KB |
| 最終ジャッジ日時 | 2025-04-19 19:58:42 |
| 合計ジャッジ時間 | 9,808 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 24 |
ソースコード
#!/usr/bin/env python3
import sys
input = sys.stdin.readline
MOD = 998244353
def main():
N, M = map(int, input().split())
A = list(map(int, input().split()))
# Build the difference‐parity vector B of length N+1
B = [0]*(N+1)
B[0] = A[0]
for i in range(1, N):
B[i] = A[i] ^ A[i-1]
B[N] = A[N-1]
k = sum(B) # weight of B
# Precompute factorials and inverses up to N+1
fac = [1]*(N+2)
for i in range(1, N+2):
fac[i] = fac[i-1]*i % MOD
invfac = [1]*(N+2)
invfac[N+1] = pow(fac[N+1], MOD-2, MOD)
for i in range(N, -1, -1):
invfac[i] = invfac[i+1]*(i+1) % MOD
# Build C1[j] = (-1)^j * C(k, j) for j=0..k
C1 = [0]*(k+1)
for j in range(k+1):
# C(k,j)
c = fac[k]*invfac[j] % MOD * invfac[k-j] % MOD
if j & 1:
c = (MOD - c)
C1[j] = c
# Build C2[s] = C(N+1-k, s) for s=0..N+1-k
n2 = (N+1) - k
C2 = [0]*(n2+1)
for s in range(n2+1):
C2[s] = fac[n2]*invfac[s] % MOD * invfac[n2-s] % MOD
# Convolution P = C1 * C2 => P[a] for a=0..N+1
# We do it in O(k * n2)
P = [0]*(N+2)
for j in range(k+1):
v1 = C1[j]
# unroll inner loop partially? fine as is
for s in range(n2+1):
P[j+s] += v1 * C2[s]
# reduce mod
for a in range(N+2):
P[a] %= MOD
# Precompute T and the S(a) = sum_{i<j}(-1)^{X_i+X_j} for |X|=a
T = N*(N+1)//2 % MOD
S = [0]*(N+2)
nn = N+1
for a in range(N+2):
# b = nn - a
# S(a) = [#equal pairs] - [#unequal pairs]
# = (C(a,2)+C(b,2)) - a*b = T - 2ab mod
ab2 = (2 * a * (nn - a)) % MOD
S[a] = (T - ab2) % MOD
# Sum up P[a] * S[a]^M
total = 0
for a in range(N+2):
total = (total + P[a] * pow(S[a], M, MOD)) % MOD
# Multiply by 2^{-(N+1)} mod
inv2 = (MOD + 1)//2
total = total * pow(inv2, N+1, MOD) % MOD
print(total)
if __name__ == "__main__":
main()