結果
問題 |
No.3117 Reversible Tile
|
ユーザー |
|
提出日時 | 2025-04-19 19:58:31 |
言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
結果 |
AC
|
実行時間 | 962 ms / 3,000 ms |
コード長 | 2,004 bytes |
コンパイル時間 | 401 ms |
コンパイル使用メモリ | 12,544 KB |
実行使用メモリ | 11,648 KB |
最終ジャッジ日時 | 2025-04-19 19:58:42 |
合計ジャッジ時間 | 9,808 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 24 |
ソースコード
#!/usr/bin/env python3 import sys input = sys.stdin.readline MOD = 998244353 def main(): N, M = map(int, input().split()) A = list(map(int, input().split())) # Build the difference‐parity vector B of length N+1 B = [0]*(N+1) B[0] = A[0] for i in range(1, N): B[i] = A[i] ^ A[i-1] B[N] = A[N-1] k = sum(B) # weight of B # Precompute factorials and inverses up to N+1 fac = [1]*(N+2) for i in range(1, N+2): fac[i] = fac[i-1]*i % MOD invfac = [1]*(N+2) invfac[N+1] = pow(fac[N+1], MOD-2, MOD) for i in range(N, -1, -1): invfac[i] = invfac[i+1]*(i+1) % MOD # Build C1[j] = (-1)^j * C(k, j) for j=0..k C1 = [0]*(k+1) for j in range(k+1): # C(k,j) c = fac[k]*invfac[j] % MOD * invfac[k-j] % MOD if j & 1: c = (MOD - c) C1[j] = c # Build C2[s] = C(N+1-k, s) for s=0..N+1-k n2 = (N+1) - k C2 = [0]*(n2+1) for s in range(n2+1): C2[s] = fac[n2]*invfac[s] % MOD * invfac[n2-s] % MOD # Convolution P = C1 * C2 => P[a] for a=0..N+1 # We do it in O(k * n2) P = [0]*(N+2) for j in range(k+1): v1 = C1[j] # unroll inner loop partially? fine as is for s in range(n2+1): P[j+s] += v1 * C2[s] # reduce mod for a in range(N+2): P[a] %= MOD # Precompute T and the S(a) = sum_{i<j}(-1)^{X_i+X_j} for |X|=a T = N*(N+1)//2 % MOD S = [0]*(N+2) nn = N+1 for a in range(N+2): # b = nn - a # S(a) = [#equal pairs] - [#unequal pairs] # = (C(a,2)+C(b,2)) - a*b = T - 2ab mod ab2 = (2 * a * (nn - a)) % MOD S[a] = (T - ab2) % MOD # Sum up P[a] * S[a]^M total = 0 for a in range(N+2): total = (total + P[a] * pow(S[a], M, MOD)) % MOD # Multiply by 2^{-(N+1)} mod inv2 = (MOD + 1)//2 total = total * pow(inv2, N+1, MOD) % MOD print(total) if __name__ == "__main__": main()