結果

問題 No.1696 Nonnil
ユーザー qwewe
提出日時 2025-05-14 13:06:58
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 7,961 bytes
コンパイル時間 147 ms
コンパイル使用メモリ 82,608 KB
実行使用メモリ 417,032 KB
最終ジャッジ日時 2025-05-14 13:08:41
合計ジャッジ時間 5,953 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 5 TLE * 1 -- * 33
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Set higher recursion depth if needed, although this code is iterative.
# sys.setrecursionlimit(2000) 

MOD = 998244353

def fast_pow(base, power):
    """
    Computes (base^power) % MOD efficiently using modular exponentiation.
    Handles potentially large power N.
    """
    result = 1
    # Ensure base is within [0, MOD-1]
    base %= MOD 
    while power > 0:
        if power % 2 == 1:
            result = (result * base) % MOD
        base = (base * base) % MOD
        power //= 2
    return result

def solve():
    # Read input: N, K, M, and M intervals
    N, K = map(int, sys.stdin.readline().split())
    M = int(sys.stdin.readline())
    
    intervals = []
    if M > 0:
        for _ in range(M):
            intervals.append(list(map(int, sys.stdin.readline().split())))

    # Base case: If there are no constraints (M=0), any sequence using elements from {1, ..., K} is valid.
    # The total number of such sequences is K^N.
    if M == 0:
        print(fast_pow(K, N))
        return

    # Identify all unique endpoints defining elementary intervals.
    # Include 1 and K+1 as boundaries.
    endpoints = {1, K + 1}
    for l, r in intervals:
        endpoints.add(l)
        # Add r+1 as an endpoint. Ensure it doesn't go excessively beyond K+1.
        # Problem constraints say 1 <= L <= R <= K, so r+1 <= K+1.
        endpoints.add(r + 1)

    # Sort the unique endpoints.
    sorted_endpoints = sorted(list(endpoints))
    
    # Define elementary intervals E_k = [p_k, p_{k+1}-1].
    # Store them as [start, end, length]. Only include intervals within [1, K].
    elem_intervals = []
    for i in range(len(sorted_endpoints) - 1):
        p_start = sorted_endpoints[i]
        p_end = sorted_endpoints[i+1] - 1
        
        # Only consider intervals that start within [1, K] and are non-empty.
        if p_start <= K and p_start <= p_end:
            # The interval might extend beyond K, clip it to K.
            effective_end = min(p_end, K)
            # Check if the clipped interval is valid.
            if p_start <= effective_end:
                 elem_intervals.append([p_start, effective_end, effective_end - p_start + 1])

    # T_prime is the number of elementary intervals.
    T_prime = len(elem_intervals) 

    # For each elementary interval E_k, compute C_k: the set of constraint indices i such that E_k is fully contained within interval i.
    # Store C_k as a list of constraint indices.
    C = [[] for _ in range(T_prime)]
    for k in range(T_prime):
        ek_start, ek_end, _ = elem_intervals[k]
        for i in range(M):
            li, ri = intervals[i]
            # Check containment condition: L_i <= start(E_k) AND R_i >= end(E_k)
            if li <= ek_start and ri >= ek_end:
                C[k].append(i)

    # Identify the distinct non-empty sets among C_k. These are denoted as S_j.
    # Use a dictionary S_map to map the tuple representation of a set C_k to its unique index j.
    S_map = {}
    # S_list stores the actual constraint index lists for each distinct set S_j.
    S_list = [] 

    for k in range(T_prime):
         # Convert list C[k] to a sorted tuple to use as a hashable dictionary key.
         Ck_tuple = tuple(sorted(C[k]))
         # Only consider non-empty sets C_k.
         if len(Ck_tuple) > 0:
             if Ck_tuple not in S_map:
                 # Assign a new index j if this set C_k is encountered for the first time.
                 new_idx = len(S_list)
                 S_map[Ck_tuple] = new_idx
                 # Store the list of constraint indices for this S_j set.
                 S_list.append(list(Ck_tuple)) 

    # p is the number of distinct non-empty sets S_j.
    p = len(S_list) 

    # Calculate W_j: the total length of elementary intervals E_k whose corresponding set C_k is exactly S_j.
    # W[0] stores W_0 (total length for C_k = empty set).
    # W[j+1] stores W_{j+1} (total length for C_k = S_{j+1}). Note 1-based indexing for S_j mapping.
    W = [0] * (p + 1) 
    
    for k in range(T_prime):
        _, _, wk = elem_intervals[k] # wk is the length of E_k.
        Ck_tuple = tuple(sorted(C[k]))
        if len(Ck_tuple) == 0:
            # If C_k is empty, contribute to W_0.
            W[0] = (W[0] + wk)
        else:
            # If C_k is non-empty, find its index j using S_map.
            j = S_map[Ck_tuple]
            # Contribute length wk to W_{j+1}.
            W[j+1] = (W[j+1] + wk)

    # Partition constraints {0, ..., M-1} into equivalence classes E_r.
    # Two constraints i, i' are equivalent if they belong to the same set of S_j's.
    # The signature of a constraint i is the set {j+1 | i \in S_j}. We use 1-based index for j.
    
    # Precompute for each constraint i, the list of indices j+1 such that i is in S_j.
    i_to_Sj_indices = [[] for _ in range(M)]
    for j in range(p):
      for i in S_list[j]: # For each constraint index i listed in S_j
          i_to_Sj_indices[i].append(j + 1) # Add 1-based index j+1 to i's signature list.
    
    # Map signature tuple -> count of constraints with this signature (m_r).
    constraint_signatures = {}
    for i in range(M):
        # The signature list is precomputed. Sort it and convert to tuple.
        signature = tuple(sorted(i_to_Sj_indices[i]))
        # Increment the count for this signature.
        if signature not in constraint_signatures:
            constraint_signatures[signature] = 0 
        constraint_signatures[signature] += 1

    # equiv_classes stores pairs (P_r_tuple, m_r).
    # P_r_tuple is the signature tuple {j+1 | i \in S_j for any i in class E_r}.
    # m_r is the count of constraints in class E_r.
    equiv_classes = list(constraint_signatures.items())

    # q is the number of distinct equivalence classes.
    q = len(equiv_classes)

    # Final calculation using the derived formula based on inclusion-exclusion over equivalence classes.
    # Sum = Sum_{R subset {1..q}} (-1)^|R| * (V_R)^N mod MOD
    # V_R = W_0 + Sum_{j=1..p} W_j * Indicator[forall r in R, j+1 not in P_r]
    total_sum = 0
    
    # Iterate through all 2^q subsets R of equivalence classes indices {0..q-1}.
    for i in range(1 << q):
        # R_indices holds the indices {0..q-1} of classes included in the current subset R.
        R_indices = [] 
        for k in range(q):
            if (i >> k) & 1:
                R_indices.append(k)
        
        # Calculate V_R = W_0 + sum_{j=1..p} W_j * Indicator[...]
        current_sum_W = W[0] # Start with W_0
        
        # Iterate over S_j sets represented by indices j=1..p.
        for j in range(1, p + 1): 
            keep_Wj = True # Assume W_j should be included by default.
            # Check the condition: for all r in R, j must not be in P_r.
            for r_idx in R_indices:
                 Pr_tuple, _ = equiv_classes[r_idx] # Get signature tuple P_r for class r_idx.
                 # Check efficiently if j (the 1-based index) is present in the sorted tuple P_r.
                 # Using simple 'in' check on tuple. Binary search could be used if needed for optimization.
                 if j in Pr_tuple: 
                     # Condition violated: j is in P_r for some r in R.
                     keep_Wj = False 
                     break # No need to check other r's for this j.
            
            if keep_Wj:
                # If condition holds for all r in R, add W_j to the sum.
                current_sum_W = (current_sum_W + W[j])

        # Calculate term (V_R)^N mod MOD. Apply modulo to base first.
        term = fast_pow(current_sum_W, N) 
        
        # Apply the inclusion-exclusion sign (-1)^|R|.
        if len(R_indices) % 2 == 1: # If size of R is odd
            total_sum = (total_sum - term + MOD) % MOD # Subtract term
        else: # If size of R is even
            total_sum = (total_sum + term) % MOD # Add term

    # Print the final result.
    print(total_sum)

solve()
0