結果
問題 |
No.1891 Static Xor Range Composite Query
|
ユーザー |
![]() |
提出日時 | 2025-04-09 20:58:10 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,193 bytes |
コンパイル時間 | 437 ms |
コンパイル使用メモリ | 82,352 KB |
実行使用メモリ | 215,060 KB |
最終ジャッジ日時 | 2025-04-09 21:00:55 |
合計ジャッジ時間 | 31,167 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 20 TLE * 1 -- * 9 |
ソースコード
import sys MOD = 998244353 def main(): sys.setrecursionlimit(1 << 25) N, Q = map(int, sys.stdin.readline().split()) a = [] b = [] for _ in range(N): ai, bi = map(int, sys.stdin.readline().split()) a.append(ai % MOD) b.append(bi % MOD) class SegmentTreeNode: def __init__(self, l, r): self.l = l self.r = r self.left = None self.right = None self.matrix = None def build(l, r): node = SegmentTreeNode(l, r) if l == r: node.matrix = (a[l], b[l]) else: mid = (l + r) // 2 node.left = build(l, mid) node.right = build(mid + 1, r) a1, b1 = node.left.matrix a2, b2 = node.right.matrix a_combined = (a2 * a1) % MOD b_combined = (a2 * b1 + b2) % MOD node.matrix = (a_combined, b_combined) return node root = build(0, N - 1) def query_segment(node, l, r): if node.r < l or node.l > r: return (1, 0) if l <= node.l and node.r <= r: return node.matrix left_a, left_b = query_segment(node.left, l, r) right_a, right_b = query_segment(node.right, l, r) combined_a = (right_a * left_a) % MOD combined_b = (right_a * left_b + right_b) % MOD return (combined_a, combined_b) def get_combined(l, r): if l > r: return (1, 0) return query_segment(root, l, r) def get_intervals(start_k, end_k, p): start_i = start_k ^ p end_i = end_k ^ p if start_i > end_i: start_i, end_i = end_i, start_i intervals = [] s = start_i e = end_i while s <= e: m = (s | ((1 << 30) - 1)) & (~((1 << (len(bin(s ^ e)) - 2)) - 1)) if m > e: m = e intervals.append((s, m)) s = m + 1 return intervals for _ in range(Q): l, r, p, x = map(int, sys.stdin.readline().split()) start_k = l end_k = r - 1 if start_k > end_k: print(x % MOD) continue all_intervals = [] current_start = start_k while current_start <= end_k: next_end = current_start | ((current_start ^ p) ^ ((current_start ^ p) | (current_start ^ p))) next_end = min(next_end, end_k) interval_start_i = current_start ^ p interval_end_i = next_end ^ p if interval_start_i > interval_end_i: interval_start_i, interval_end_i = interval_end_i, interval_start_i all_intervals.append((interval_start_i, interval_end_i)) current_start = next_end + 1 combined_a = 1 combined_b = 0 for (s, e) in all_intervals: if s > e: continue seg_a, seg_b = get_combined(s, e) combined_a = (combined_a * seg_a) % MOD combined_b = (seg_a * combined_b + seg_b) % MOD res = (combined_a * x + combined_b) % MOD print(res) if __name__ == '__main__': main()