結果
問題 |
No.1696 Nonnil
|
ユーザー |
![]() |
提出日時 | 2025-05-14 13:06:58 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 7,961 bytes |
コンパイル時間 | 147 ms |
コンパイル使用メモリ | 82,608 KB |
実行使用メモリ | 417,032 KB |
最終ジャッジ日時 | 2025-05-14 13:08:41 |
合計ジャッジ時間 | 5,953 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 5 TLE * 1 -- * 33 |
ソースコード
import sys # Set higher recursion depth if needed, although this code is iterative. # sys.setrecursionlimit(2000) MOD = 998244353 def fast_pow(base, power): """ Computes (base^power) % MOD efficiently using modular exponentiation. Handles potentially large power N. """ result = 1 # Ensure base is within [0, MOD-1] base %= MOD while power > 0: if power % 2 == 1: result = (result * base) % MOD base = (base * base) % MOD power //= 2 return result def solve(): # Read input: N, K, M, and M intervals N, K = map(int, sys.stdin.readline().split()) M = int(sys.stdin.readline()) intervals = [] if M > 0: for _ in range(M): intervals.append(list(map(int, sys.stdin.readline().split()))) # Base case: If there are no constraints (M=0), any sequence using elements from {1, ..., K} is valid. # The total number of such sequences is K^N. if M == 0: print(fast_pow(K, N)) return # Identify all unique endpoints defining elementary intervals. # Include 1 and K+1 as boundaries. endpoints = {1, K + 1} for l, r in intervals: endpoints.add(l) # Add r+1 as an endpoint. Ensure it doesn't go excessively beyond K+1. # Problem constraints say 1 <= L <= R <= K, so r+1 <= K+1. endpoints.add(r + 1) # Sort the unique endpoints. sorted_endpoints = sorted(list(endpoints)) # Define elementary intervals E_k = [p_k, p_{k+1}-1]. # Store them as [start, end, length]. Only include intervals within [1, K]. elem_intervals = [] for i in range(len(sorted_endpoints) - 1): p_start = sorted_endpoints[i] p_end = sorted_endpoints[i+1] - 1 # Only consider intervals that start within [1, K] and are non-empty. if p_start <= K and p_start <= p_end: # The interval might extend beyond K, clip it to K. effective_end = min(p_end, K) # Check if the clipped interval is valid. if p_start <= effective_end: elem_intervals.append([p_start, effective_end, effective_end - p_start + 1]) # T_prime is the number of elementary intervals. T_prime = len(elem_intervals) # For each elementary interval E_k, compute C_k: the set of constraint indices i such that E_k is fully contained within interval i. # Store C_k as a list of constraint indices. C = [[] for _ in range(T_prime)] for k in range(T_prime): ek_start, ek_end, _ = elem_intervals[k] for i in range(M): li, ri = intervals[i] # Check containment condition: L_i <= start(E_k) AND R_i >= end(E_k) if li <= ek_start and ri >= ek_end: C[k].append(i) # Identify the distinct non-empty sets among C_k. These are denoted as S_j. # Use a dictionary S_map to map the tuple representation of a set C_k to its unique index j. S_map = {} # S_list stores the actual constraint index lists for each distinct set S_j. S_list = [] for k in range(T_prime): # Convert list C[k] to a sorted tuple to use as a hashable dictionary key. Ck_tuple = tuple(sorted(C[k])) # Only consider non-empty sets C_k. if len(Ck_tuple) > 0: if Ck_tuple not in S_map: # Assign a new index j if this set C_k is encountered for the first time. new_idx = len(S_list) S_map[Ck_tuple] = new_idx # Store the list of constraint indices for this S_j set. S_list.append(list(Ck_tuple)) # p is the number of distinct non-empty sets S_j. p = len(S_list) # Calculate W_j: the total length of elementary intervals E_k whose corresponding set C_k is exactly S_j. # W[0] stores W_0 (total length for C_k = empty set). # W[j+1] stores W_{j+1} (total length for C_k = S_{j+1}). Note 1-based indexing for S_j mapping. W = [0] * (p + 1) for k in range(T_prime): _, _, wk = elem_intervals[k] # wk is the length of E_k. Ck_tuple = tuple(sorted(C[k])) if len(Ck_tuple) == 0: # If C_k is empty, contribute to W_0. W[0] = (W[0] + wk) else: # If C_k is non-empty, find its index j using S_map. j = S_map[Ck_tuple] # Contribute length wk to W_{j+1}. W[j+1] = (W[j+1] + wk) # Partition constraints {0, ..., M-1} into equivalence classes E_r. # Two constraints i, i' are equivalent if they belong to the same set of S_j's. # The signature of a constraint i is the set {j+1 | i \in S_j}. We use 1-based index for j. # Precompute for each constraint i, the list of indices j+1 such that i is in S_j. i_to_Sj_indices = [[] for _ in range(M)] for j in range(p): for i in S_list[j]: # For each constraint index i listed in S_j i_to_Sj_indices[i].append(j + 1) # Add 1-based index j+1 to i's signature list. # Map signature tuple -> count of constraints with this signature (m_r). constraint_signatures = {} for i in range(M): # The signature list is precomputed. Sort it and convert to tuple. signature = tuple(sorted(i_to_Sj_indices[i])) # Increment the count for this signature. if signature not in constraint_signatures: constraint_signatures[signature] = 0 constraint_signatures[signature] += 1 # equiv_classes stores pairs (P_r_tuple, m_r). # P_r_tuple is the signature tuple {j+1 | i \in S_j for any i in class E_r}. # m_r is the count of constraints in class E_r. equiv_classes = list(constraint_signatures.items()) # q is the number of distinct equivalence classes. q = len(equiv_classes) # Final calculation using the derived formula based on inclusion-exclusion over equivalence classes. # Sum = Sum_{R subset {1..q}} (-1)^|R| * (V_R)^N mod MOD # V_R = W_0 + Sum_{j=1..p} W_j * Indicator[forall r in R, j+1 not in P_r] total_sum = 0 # Iterate through all 2^q subsets R of equivalence classes indices {0..q-1}. for i in range(1 << q): # R_indices holds the indices {0..q-1} of classes included in the current subset R. R_indices = [] for k in range(q): if (i >> k) & 1: R_indices.append(k) # Calculate V_R = W_0 + sum_{j=1..p} W_j * Indicator[...] current_sum_W = W[0] # Start with W_0 # Iterate over S_j sets represented by indices j=1..p. for j in range(1, p + 1): keep_Wj = True # Assume W_j should be included by default. # Check the condition: for all r in R, j must not be in P_r. for r_idx in R_indices: Pr_tuple, _ = equiv_classes[r_idx] # Get signature tuple P_r for class r_idx. # Check efficiently if j (the 1-based index) is present in the sorted tuple P_r. # Using simple 'in' check on tuple. Binary search could be used if needed for optimization. if j in Pr_tuple: # Condition violated: j is in P_r for some r in R. keep_Wj = False break # No need to check other r's for this j. if keep_Wj: # If condition holds for all r in R, add W_j to the sum. current_sum_W = (current_sum_W + W[j]) # Calculate term (V_R)^N mod MOD. Apply modulo to base first. term = fast_pow(current_sum_W, N) # Apply the inclusion-exclusion sign (-1)^|R|. if len(R_indices) % 2 == 1: # If size of R is odd total_sum = (total_sum - term + MOD) % MOD # Subtract term else: # If size of R is even total_sum = (total_sum + term) % MOD # Add term # Print the final result. print(total_sum) solve()