結果

問題 No.2327 Inversion Sum
ユーザー lam6er
提出日時 2025-03-20 20:26:52
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 263 ms / 2,000 ms
コード長 3,952 bytes
コンパイル時間 142 ms
コンパイル使用メモリ 82,244 KB
実行使用メモリ 113,648 KB
最終ジャッジ日時 2025-03-20 20:28:25
合計ジャッジ時間 5,174 ms
ジャッジサーバーID (参考情報) judge5 / judge4
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

import bisect
import sys

MOD = 998244353


def _count_inversions(values, max_value):
    """Return the number of pairs i < j with values[i] > values[j], modulo MOD.

    Uses a Fenwick (binary indexed) tree indexed 1..max_value, so every entry
    of *values* must lie in that range.
    """
    tree = [0] * (max_value + 1)
    inversions = 0
    # Scan right to left: the tree holds the values already seen to the right,
    # so a prefix query over 1..v-1 counts the smaller ones among them.
    for v in reversed(values):
        i = v - 1
        while i > 0:
            inversions += tree[i]
            i -= i & -i
        i = v
        while i <= max_value:
            tree[i] += 1
            i += i & -i
    return inversions % MOD


def _inversion_sum(n, pos_of):
    """Return the sum of inversion counts over all constrained permutations, mod MOD.

    A permutation of 1..n is counted when every value v in *pos_of* sits at
    the 1-based position pos_of[v]. The remaining L "free" values fill the
    remaining L "free" positions in all L! possible ways.

    The total splits into three independent parts:
      * fixed/fixed pairs: inversions among the fixed values, taken in
        position order, each appearing in all L! fillings;
      * free/free pairs: each of the C(L, 2) free-position pairs is inverted
        in exactly half of the L! fillings;
      * fixed/free pairs: a specific (free value, free position) placement
        occurs in (L-1)! fillings.

    NOTE(review): values and positions are assumed pairwise distinct and
    within 1..n; this is not validated.
    """
    free_values = [v for v in range(1, n + 1) if v not in pos_of]  # ascending
    free_count = len(free_values)

    # fact[i] = i! mod MOD; only fact[free_count] and fact[free_count - 1] are used.
    fact = [1] * (free_count + 1)
    for i in range(2, free_count + 1):
        fact[i] = fact[i - 1] * i % MOD

    # Fixed/fixed contribution.
    fixed_in_position_order = [
        value for value, _ in sorted(pos_of.items(), key=lambda item: item[1])
    ]
    total = _count_inversions(fixed_in_position_order, n) * fact[free_count] % MOD

    # Free/free contribution: C(L, 2) * L! / 2; this is zero when L < 2.
    inv2 = (MOD + 1) // 2  # modular inverse of 2
    free_pairs = free_count * (free_count - 1) // 2
    total += free_pairs % MOD * fact[free_count] % MOD * inv2 % MOD

    # Fixed/free contribution.
    if pos_of and free_count:
        positions = sorted(pos_of.values())
        fixed_count = len(positions)
        placements = 0
        for value, pos in pos_of.items():
            larger_free = free_count - bisect.bisect_right(free_values, value)
            smaller_free = free_count - larger_free
            # Free positions = all positions on that side minus the fixed ones there.
            free_left = (pos - 1) - bisect.bisect_left(positions, pos)
            free_right = (n - pos) - (fixed_count - bisect.bisect_right(positions, pos))
            # Inversion iff a larger free value lands left of pos,
            # or a smaller free value lands right of it.
            placements += larger_free * free_left + smaller_free * free_right
        total += placements % MOD * fact[free_count - 1]

    return total % MOD


def main():
    """Read "N M" followed by M "value position" pairs from stdin; print the answer."""
    tokens = list(map(int, sys.stdin.read().split()))
    n, m = tokens[0], tokens[1]
    # Tokens after the header alternate value, position, value, position, ...
    pos_of = dict(zip(tokens[2:2 + 2 * m:2], tokens[3:3 + 2 * m:2]))
    print(_inversion_sum(n, pos_of))


if __name__ == '__main__':
    main()
0