結果
| 問題 | No.3450 Permutation of Even Scores |
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2026-03-11 11:57:43 |
| 言語 | PyPy3 (7.3.17) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 5,735 bytes |
| 記録 | |
| コンパイル時間 | 277 ms |
| コンパイル使用メモリ | 85,300 KB |
| 実行使用メモリ | 189,872 KB |
| 最終ジャッジ日時 | 2026-03-11 11:59:38 |
| 合計ジャッジ時間 | 87,299 ms |
|
ジャッジサーバーID (参考情報) |
judge2_0 / judge1_0 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 11 TLE * 35 |
ソースコード
import sys
MOD = 998244353
PRIMITIVE_ROOT = 3
INV2 = (MOD + 1) // 2
# ------------------------------------------------------------
# 基本ユーティリティ
# ------------------------------------------------------------
def build_factorials(n: int) -> list[int]:
"""
fact[i] = i! mod MOD を返す
"""
fact = [1] * (n + 1)
for i in range(1, n + 1):
fact[i] = fact[i - 1] * i % MOD
return fact
# ------------------------------------------------------------
# NTT (Number Theoretic Transform)
# ------------------------------------------------------------
def ntt(a: list[int], invert: bool) -> None:
"""
配列 a に対して in-place で NTT / inverse NTT を行う。
長さ len(a) は 2 の冪であることを前提とする。
"""
n = len(a)
# bit-reversal permutation
j = 0
for i in range(1, n):
bit = n >> 1
while j & bit:
j ^= bit
bit >>= 1
j ^= bit
if i < j:
a[i], a[j] = a[j], a[i]
length = 2
while length <= n:
wlen = pow(PRIMITIVE_ROOT, (MOD - 1) // length, MOD)
if invert:
wlen = pow(wlen, MOD - 2, MOD)
half = length >> 1
for start in range(0, n, length):
w = 1
for i in range(start, start + half):
u = a[i]
v = a[i + half] * w % MOD
a[i] = (u + v) % MOD
a[i + half] = (u - v) % MOD
w = w * wlen % MOD
length <<= 1
if invert:
inv_n = pow(n, MOD - 2, MOD)
for i in range(n):
a[i] = a[i] * inv_n % MOD
def convolution(a: list[int], b: list[int]) -> list[int]:
"""
a と b の畳み込みを mod MOD で返す。
短いときは愚直、長いときは NTT を使う。
"""
if not a or not b:
return []
# 可読性優先で閾値は控えめ
if min(len(a), len(b)) <= 60:
res = [0] * (len(a) + len(b) - 1)
for i, x in enumerate(a):
if x == 0:
continue
for j, y in enumerate(b):
res[i + j] = (res[i + j] + x * y) % MOD
return res
need = len(a) + len(b) - 1
n = 1
while n < need:
n <<= 1
fa = a[:] + [0] * (n - len(a))
fb = b[:] + [0] * (n - len(b))
ntt(fa, invert=False)
ntt(fb, invert=False)
for i in range(n):
fa[i] = fa[i] * fb[i] % MOD
ntt(fa, invert=True)
return fa[:need]
# ------------------------------------------------------------
# 本体
# ------------------------------------------------------------
def count_even_score_permutations(n: int, a_values: list[int]) -> int:
"""
問題の答えを返す。
方針:
- signed = Σ_P (-1)^(score(P)) を求める
- 欲しい偶数個数 = (n! + signed) / 2
signed は部分集合 T ⊆ A の包除から
signed = n! + Σ_T≠∅ (-2)^|T| * C(T)
と書ける。
ここで C(T) は「T に対応する条件が全部成り立つ順列数」。
さらに T の最大値を末尾にもつ DP にすると
dp[x] = -2 * ( x! + Σ_{y<x} dp[y] * (x-y+1)! )
が得られる。
この畳み込みを CDQ 分割統治 + NTT で計算する。
"""
fact = build_factorials(n)
# A の要素かどうかを O(1) で判定するための配列
is_target = [False] * (n + 1)
for x in a_values:
is_target[x] = True
# kernel[d] = (d+1)! (d >= 1), kernel[0] = 0
# dp[y] から x への寄与が kernel[x-y] になる
kernel = [0] * n
for d in range(1, n):
kernel[d] = fact[d + 1]
# dp[x]:
# x ∈ A を末尾とする部分集合 T の寄与
# ただし最後の (n-x+1)! はまだ掛けていない
dp = [0] * (n + 1)
# add[x]:
# すでに確定した左側から x へ届く
# Σ dp[y] * (x-y+1)! を溜める配列
add = [0] * (n + 1)
sys.setrecursionlimit(1_000_000)
def cdq(left: int, right: int) -> None:
"""
区間 [left, right] について dp を求める。
先に左半分を確定し、その寄与を右半分へ畳み込みで流す。
"""
if left == right:
if is_target[left]:
# dp[x] = -2 * (x! + add[x])
dp[left] = (-2 * (fact[left] + add[left])) % MOD
return
mid = (left + right) // 2
# まず左半分を完成させる
cdq(left, mid)
# 左半分から右半分へ寄与を流す
left_values = dp[left:mid + 1]
# x - y の最大は right - left
# convolution の添字 x-left に対応するように
# kernel[0..right-left] を使えばよい
kernel_part = kernel[:right - left + 1]
conv = convolution(left_values, kernel_part)
# 実際に必要なのは x ∈ [mid+1, right] の位置
for x in range(mid + 1, right + 1):
add[x] = (add[x] + conv[x - left]) % MOD
# 次に右半分
cdq(mid + 1, right)
cdq(1, n)
# signed = n! + Σ_{x∈A} dp[x] * (n-x+1)!
signed = fact[n]
for x in a_values:
signed = (signed + dp[x] * fact[n - x + 1]) % MOD
# 偶数個数 = (全体 + signed) / 2
answer = (fact[n] + signed) * INV2 % MOD
return answer
# ------------------------------------------------------------
# 入出力
# ------------------------------------------------------------
def main() -> None:
input = sys.stdin.readline
n, m = map(int, input().split())
a_values = list(map(int, input().split()))
print(count_even_score_permutations(n, a_values))
if __name__ == "__main__":
main()