結果
問題 | No.1193 Penguin Sequence |
ユーザー |
![]() |
提出日時 | 2025-04-09 20:56:08 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 699 ms / 2,000 ms |
コード長 | 3,543 bytes |
コンパイル時間 | 372 ms |
コンパイル使用メモリ | 82,208 KB |
実行使用メモリ | 167,368 KB |
最終ジャッジ日時 | 2025-04-09 20:56:56 |
合計ジャッジ時間 | 19,860 ms |
ジャッジサーバーID (参考情報) | judge3 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 38 |
ソースコード
import bisect
import sys

MOD = 998244353


def _count_inversions(a):
    """Return the number of index pairs i < j with a[i] > a[j].

    Scans ``a`` right to left with a Fenwick (binary indexed) tree over the
    coordinate-compressed values already seen; for each element it adds how
    many of those later elements are strictly smaller.
    """
    values = sorted(set(a))
    size = len(values)
    tree = [0] * (size + 1)  # 1-based Fenwick tree indexed by value rank
    total = 0
    for v in reversed(a):
        # The 0-based rank equals the number of distinct values strictly
        # below v, so the prefix sum over positions 1..rank counts the
        # strictly smaller elements seen so far.
        rank = bisect.bisect_left(values, v)
        i = rank
        while i > 0:
            total += tree[i]
            i -= i & -i
        i = rank + 1
        while i <= size:
            tree[i] += 1
            i += i & -i
    return total


def _count_less_pairs(a):
    """Return the number of ordered index pairs (x, y) with a[x] < a[y].

    Positions are unrestricted (x may come after y), so by symmetry this is
    also the number of ordered pairs with a[x] > a[y].
    """
    ordered = sorted(a)
    n = len(ordered)
    # n - bisect_right(...) = number of elements strictly greater than v.
    return sum(n - bisect.bisect_right(ordered, v) for v in a)


def _product_of_binomials(n):
    """Return prod_{i=1..n} C(n, i) modulo MOD (1 when n == 0)."""
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)  # Fermat inverse; MOD is prime
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    result = 1
    for i in range(1, n + 1):
        result = result * fact[n] % MOD * inv_fact[i] % MOD * inv_fact[n - i] % MOD
    return result


def solve(a):
    """Return the answer for the sequence ``a`` modulo MOD.

    With n = len(a) and T = prod_{k=1..n} C(n, k), the closed form below
    appears to sum, over every way of picking one length-k subsequence B_k
    of ``a`` for each k = 1..n, the inversion count of B_1 + ... + B_n
    (it agrees with brute-force enumeration on small inputs). Averaged over
    the T choices:

    * inside B_k, a given pair of positions survives with probability
      k(k-1) / (n(n-1)), giving inv(a) * sum_k k(k-1) / (n(n-1));
    * B_k and B_l (k < l) are chosen independently and each position lies
      in B_k with probability k/n, giving P * sum_{k<l} k*l / n^2, where P
      counts ordered value pairs with a[x] > a[y].

    The expectation is multiplied back by T.
    """
    n = len(a)
    less_pairs = _count_less_pairs(a) % MOD
    inversions = _count_inversions(a) % MOD
    s1 = n * (n + 1) // 2                # sum_{k=1..n} k
    s2 = n * (n + 1) * (2 * n + 1) // 6  # sum_{k=1..n} k^2
    cross = (s1 * s1 - s2) // 2 % MOD    # sum_{k<l} k*l (exact integer)
    within = (s2 - s1) % MOD             # sum_{k=1..n} k(k-1)
    term1 = less_pairs * cross % MOD * pow(n * n % MOD, MOD - 2, MOD) % MOD
    if n >= 2:
        term2 = inversions * within % MOD * pow(n * (n - 1) % MOD, MOD - 2, MOD) % MOD
    else:
        term2 = 0  # a block of length < 2 contains no pair
    return (term1 + term2) * _product_of_binomials(n) % MOD


def main():
    """Read n followed by the sequence from stdin and print the answer."""
    n, *rest = map(int, sys.stdin.read().split())
    print(solve(rest[:n]))


if __name__ == '__main__':
    main()