結果

| 項目 | 内容 |
|---|---|
| 問題 | No.1193 Penguin Sequence |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-04-09 20:56:08 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 699 ms / 2,000 ms |
| コード長 | 3,543 bytes |
| コンパイル時間 | 372 ms |
| コンパイル使用メモリ | 82,208 KB |
| 実行使用メモリ | 167,368 KB |
| 最終ジャッジ日時 | 2025-04-09 20:56:56 |
| 合計ジャッジ時間 | 19,860 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge5 (要ログイン) |
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 38 |
ソースコード
import bisect
MOD = 998244353


def _count_greater_pairs(values):
    """Return the number of ordered position pairs (p, q) with values[p] > values[q].

    p and q range independently over all positions, so the relative order of
    the two elements in the sequence is ignored; equal values never count.
    """
    ordered = sorted(values)
    size = len(ordered)
    # Elements strictly greater than v are exactly those right of bisect_right.
    return sum(size - bisect.bisect_right(ordered, v) for v in values)


def _count_inversions(values):
    """Return the number of index pairs p < q with values[p] > values[q].

    Uses a Fenwick (binary indexed) tree over the 1-based ranks of the
    distinct values, scanning the sequence right to left.
    """
    ranks = {v: r for r, v in enumerate(sorted(set(values)), start=1)}
    size = len(ranks)
    tree = [0] * (size + 1)
    inversions = 0
    for v in reversed(values):
        r = ranks[v]
        # Count already-inserted (i.e. later) elements with a strictly smaller rank.
        i = r - 1
        while i > 0:
            inversions += tree[i]
            i -= i & -i
        # Insert the current element.
        i = r
        while i <= size:
            tree[i] += 1
            i += i & -i
    return inversions


def _solve(values):
    """Return the answer for the sequence *values*, reduced modulo MOD.

    With n = len(values), the value computed is

        T * ( G * sum_{k<l} k*l / n^2  +  I * sum_{k=1..n} k*(k-1) / (n*(n-1)) )

    where T = prod_{i=1..n} C(n, i), G = _count_greater_pairs(values),
    I = _count_inversions(values), and every division is a modular inverse.

    NOTE(review): on small inputs this equals the total inversion count of
    the concatenation B_1 B_2 ... B_n, summed over every choice of a
    length-k subsequence B_k of *values* for each k. That reading is
    reconstructed from the formula; confirm it against the problem statement.
    """
    n = len(values)
    greater_pairs = _count_greater_pairs(values) % MOD
    inversions = _count_inversions(values) % MOD

    # Factorials and inverse factorials for C(n, i) mod MOD.
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    # T = prod_{i=1..n} C(n, i).
    ways = 1
    for i in range(1, n + 1):
        ways = ways * fact[n] % MOD * inv_fact[i] % MOD * inv_fact[n - i] % MOD

    s1 = n * (n + 1) // 2                 # sum_{k=1..n} k
    s2 = n * (n + 1) * (2 * n + 1) // 6   # sum_{k=1..n} k^2
    # sum_{k<l} k*l = (s1^2 - s2) / 2 is an exact integer, so no inverse of 2 is needed.
    cross_weight = (s1 * s1 - s2) // 2 % MOD
    # sum_{k=1..n} k*(k-1) = s2 - s1.
    within_weight = (s2 - s1) % MOD

    cross = greater_pairs * cross_weight % MOD * pow(n * n % MOD, MOD - 2, MOD) % MOD
    if n >= 2:
        within = (inversions * within_weight % MOD
                  * pow(n * (n - 1) % MOD, MOD - 2, MOD) % MOD)
    else:
        # n < 2: no pair of positions exists, and within_weight is 0 anyway.
        within = 0
    return (cross + within) * ways % MOD


def main():
    """Read N followed by N integers from stdin and print the answer."""
    import sys
    n, *rest = map(int, sys.stdin.read().split())
    print(_solve(rest[:n]))
if __name__ == '__main__':
    # Script entry point: solve only when executed directly, not on import.
    main()
lam6er