結果
問題 | No.2327 Inversion Sum |
ユーザー |
![]() |
提出日時 | 2025-03-20 20:26:52 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 263 ms / 2,000 ms |
コード長 | 3,952 bytes |
コンパイル時間 | 142 ms |
コンパイル使用メモリ | 82,244 KB |
実行使用メモリ | 113,648 KB |
最終ジャッジ日時 | 2025-03-20 20:28:25 |
合計ジャッジ時間 | 5,174 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 30 |
ソースコード
import bisect
import sys

MOD = 998244353


class _FenwickTree:
    """Binary indexed tree over indices 1..size supporting point add and prefix sum."""

    def __init__(self, size):
        self.size = size
        self.tree = [0] * (size + 1)

    def update(self, idx, delta=1):
        """Add ``delta`` to the entry at 1-based index ``idx``."""
        while idx <= self.size:
            self.tree[idx] += delta
            idx += idx & -idx

    def query(self, idx):
        """Return the sum of entries at indices 1..idx (0 when idx <= 0)."""
        res = 0
        while idx > 0:
            res += self.tree[idx]
            idx -= idx & -idx
        return res


def _count_inversions(values, max_value):
    """Return the number of index pairs i < j with values[i] > values[j].

    Every element of ``values`` must lie in 1..max_value.
    """
    tree = _FenwickTree(max_value)
    inversions = 0
    # Scan right to left: each element forms an inversion with every smaller
    # element already inserted (i.e. every smaller element to its right).
    for value in reversed(values):
        inversions += tree.query(value - 1)
        tree.update(value)
    return inversions


def _solve(n, constraints):
    """Sum the inversion counts of all constrained permutations of 1..n, mod MOD.

    Args:
        n: permutation length.
        constraints: sequence of ``(value, position)`` pairs; each pins
            ``value`` to 1-based ``position``. Values and positions are
            assumed distinct (as in the problem input).

    Returns:
        The sum, over every permutation of 1..n satisfying all constraints,
        of its inversion count, reduced modulo MOD.
    """
    position_of = dict(constraints)
    fixed_positions = sorted(pos for _, pos in constraints)
    # Ascending by construction, so bisect can be used on it directly.
    free_values = [v for v in range(1, n + 1) if v not in position_of]
    num_free = len(free_values)

    fact = [1] * (n + 1)
    for i in range(2, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    # The free values fill the free positions in num_free! ways.
    completions = fact[num_free]

    # (1) Fixed/fixed pairs: identical in every completion.
    fixed_by_position = [
        value for value, _ in sorted(position_of.items(), key=lambda item: item[1])
    ]
    total = _count_inversions(fixed_by_position, n) % MOD * completions % MOD

    # (2) Free/free pairs: each of the C(L, 2) pairs of free positions is
    # inverted in exactly half of the completions.
    free_pairs = num_free * (num_free - 1) // 2
    inv_2 = (MOD + 1) // 2
    total = (total + free_pairs % MOD * completions % MOD * inv_2) % MOD

    # (3) Fixed/free pairs: a given free position holds a given free value in
    # (L - 1)! completions. So each (fixed value, free position) pair
    # contributes (#qualifying free values) * (L - 1)!.
    if num_free:
        weight = fact[num_free - 1]
        num_fixed = len(fixed_positions)
        for value, pos in position_of.items():
            larger = num_free - bisect.bisect_right(free_values, value)
            smaller = num_free - larger
            free_left = (pos - 1) - bisect.bisect_left(fixed_positions, pos)
            fixed_right = num_fixed - bisect.bisect_right(fixed_positions, pos)
            free_right = (n - pos) - fixed_right
            # Larger free values to the left and smaller free values to the
            # right of the fixed element form inversions with it.
            pair_count = larger * free_left + smaller * free_right
            total = (total + pair_count % MOD * weight) % MOD
    return total


def main():
    """Read the problem input from standard input and print the answer.

    Input (whitespace-separated): ``N M`` followed by ``M`` pairs ``P K``,
    meaning value ``P`` must be placed at 1-based position ``K``.
    """
    tokens = map(int, sys.stdin.read().split())
    n = next(tokens)
    m = next(tokens)
    constraints = [(next(tokens), next(tokens)) for _ in range(m)]
    print(_solve(n, constraints))


if __name__ == '__main__':
    main()