結果
| 問題 |
No.1696 Nonnil
|
| コンテスト | |
| ユーザー |
qwewe
|
| 提出日時 | 2025-05-14 13:06:58 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 7,961 bytes |
| コンパイル時間 | 147 ms |
| コンパイル使用メモリ | 82,608 KB |
| 実行使用メモリ | 417,032 KB |
| 最終ジャッジ日時 | 2025-05-14 13:08:41 |
| 合計ジャッジ時間 | 5,953 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 5 TLE * 1 -- * 33 |
ソースコード
import sys
# Set higher recursion depth if needed, although this code is iterative.
# sys.setrecursionlimit(2000)
MOD = 998244353
def fast_pow(base, power):
"""
Computes (base^power) % MOD efficiently using modular exponentiation.
Handles potentially large power N.
"""
result = 1
# Ensure base is within [0, MOD-1]
base %= MOD
while power > 0:
if power % 2 == 1:
result = (result * base) % MOD
base = (base * base) % MOD
power //= 2
return result
def solve():
# Read input: N, K, M, and M intervals
N, K = map(int, sys.stdin.readline().split())
M = int(sys.stdin.readline())
intervals = []
if M > 0:
for _ in range(M):
intervals.append(list(map(int, sys.stdin.readline().split())))
# Base case: If there are no constraints (M=0), any sequence using elements from {1, ..., K} is valid.
# The total number of such sequences is K^N.
if M == 0:
print(fast_pow(K, N))
return
# Identify all unique endpoints defining elementary intervals.
# Include 1 and K+1 as boundaries.
endpoints = {1, K + 1}
for l, r in intervals:
endpoints.add(l)
# Add r+1 as an endpoint. Ensure it doesn't go excessively beyond K+1.
# Problem constraints say 1 <= L <= R <= K, so r+1 <= K+1.
endpoints.add(r + 1)
# Sort the unique endpoints.
sorted_endpoints = sorted(list(endpoints))
# Define elementary intervals E_k = [p_k, p_{k+1}-1].
# Store them as [start, end, length]. Only include intervals within [1, K].
elem_intervals = []
for i in range(len(sorted_endpoints) - 1):
p_start = sorted_endpoints[i]
p_end = sorted_endpoints[i+1] - 1
# Only consider intervals that start within [1, K] and are non-empty.
if p_start <= K and p_start <= p_end:
# The interval might extend beyond K, clip it to K.
effective_end = min(p_end, K)
# Check if the clipped interval is valid.
if p_start <= effective_end:
elem_intervals.append([p_start, effective_end, effective_end - p_start + 1])
# T_prime is the number of elementary intervals.
T_prime = len(elem_intervals)
# For each elementary interval E_k, compute C_k: the set of constraint indices i such that E_k is fully contained within interval i.
# Store C_k as a list of constraint indices.
C = [[] for _ in range(T_prime)]
for k in range(T_prime):
ek_start, ek_end, _ = elem_intervals[k]
for i in range(M):
li, ri = intervals[i]
# Check containment condition: L_i <= start(E_k) AND R_i >= end(E_k)
if li <= ek_start and ri >= ek_end:
C[k].append(i)
# Identify the distinct non-empty sets among C_k. These are denoted as S_j.
# Use a dictionary S_map to map the tuple representation of a set C_k to its unique index j.
S_map = {}
# S_list stores the actual constraint index lists for each distinct set S_j.
S_list = []
for k in range(T_prime):
# Convert list C[k] to a sorted tuple to use as a hashable dictionary key.
Ck_tuple = tuple(sorted(C[k]))
# Only consider non-empty sets C_k.
if len(Ck_tuple) > 0:
if Ck_tuple not in S_map:
# Assign a new index j if this set C_k is encountered for the first time.
new_idx = len(S_list)
S_map[Ck_tuple] = new_idx
# Store the list of constraint indices for this S_j set.
S_list.append(list(Ck_tuple))
# p is the number of distinct non-empty sets S_j.
p = len(S_list)
# Calculate W_j: the total length of elementary intervals E_k whose corresponding set C_k is exactly S_j.
# W[0] stores W_0 (total length for C_k = empty set).
# W[j+1] stores W_{j+1} (total length for C_k = S_{j+1}). Note 1-based indexing for S_j mapping.
W = [0] * (p + 1)
for k in range(T_prime):
_, _, wk = elem_intervals[k] # wk is the length of E_k.
Ck_tuple = tuple(sorted(C[k]))
if len(Ck_tuple) == 0:
# If C_k is empty, contribute to W_0.
W[0] = (W[0] + wk)
else:
# If C_k is non-empty, find its index j using S_map.
j = S_map[Ck_tuple]
# Contribute length wk to W_{j+1}.
W[j+1] = (W[j+1] + wk)
# Partition constraints {0, ..., M-1} into equivalence classes E_r.
# Two constraints i, i' are equivalent if they belong to the same set of S_j's.
# The signature of a constraint i is the set {j+1 | i \in S_j}. We use 1-based index for j.
# Precompute for each constraint i, the list of indices j+1 such that i is in S_j.
i_to_Sj_indices = [[] for _ in range(M)]
for j in range(p):
for i in S_list[j]: # For each constraint index i listed in S_j
i_to_Sj_indices[i].append(j + 1) # Add 1-based index j+1 to i's signature list.
# Map signature tuple -> count of constraints with this signature (m_r).
constraint_signatures = {}
for i in range(M):
# The signature list is precomputed. Sort it and convert to tuple.
signature = tuple(sorted(i_to_Sj_indices[i]))
# Increment the count for this signature.
if signature not in constraint_signatures:
constraint_signatures[signature] = 0
constraint_signatures[signature] += 1
# equiv_classes stores pairs (P_r_tuple, m_r).
# P_r_tuple is the signature tuple {j+1 | i \in S_j for any i in class E_r}.
# m_r is the count of constraints in class E_r.
equiv_classes = list(constraint_signatures.items())
# q is the number of distinct equivalence classes.
q = len(equiv_classes)
# Final calculation using the derived formula based on inclusion-exclusion over equivalence classes.
# Sum = Sum_{R subset {1..q}} (-1)^|R| * (V_R)^N mod MOD
# V_R = W_0 + Sum_{j=1..p} W_j * Indicator[forall r in R, j+1 not in P_r]
total_sum = 0
# Iterate through all 2^q subsets R of equivalence classes indices {0..q-1}.
for i in range(1 << q):
# R_indices holds the indices {0..q-1} of classes included in the current subset R.
R_indices = []
for k in range(q):
if (i >> k) & 1:
R_indices.append(k)
# Calculate V_R = W_0 + sum_{j=1..p} W_j * Indicator[...]
current_sum_W = W[0] # Start with W_0
# Iterate over S_j sets represented by indices j=1..p.
for j in range(1, p + 1):
keep_Wj = True # Assume W_j should be included by default.
# Check the condition: for all r in R, j must not be in P_r.
for r_idx in R_indices:
Pr_tuple, _ = equiv_classes[r_idx] # Get signature tuple P_r for class r_idx.
# Check efficiently if j (the 1-based index) is present in the sorted tuple P_r.
# Using simple 'in' check on tuple. Binary search could be used if needed for optimization.
if j in Pr_tuple:
# Condition violated: j is in P_r for some r in R.
keep_Wj = False
break # No need to check other r's for this j.
if keep_Wj:
# If condition holds for all r in R, add W_j to the sum.
current_sum_W = (current_sum_W + W[j])
# Calculate term (V_R)^N mod MOD. Apply modulo to base first.
term = fast_pow(current_sum_W, N)
# Apply the inclusion-exclusion sign (-1)^|R|.
if len(R_indices) % 2 == 1: # If size of R is odd
total_sum = (total_sum - term + MOD) % MOD # Subtract term
else: # If size of R is even
total_sum = (total_sum + term) % MOD # Add term
# Print the final result.
print(total_sum)
solve()
qwewe