結果
| 問題 |
No.1891 Static Xor Range Composite Query
|
| ユーザー |
lam6er
|
| 提出日時 | 2025-04-09 20:58:10 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 3,193 bytes |
| コンパイル時間 | 437 ms |
| コンパイル使用メモリ | 82,352 KB |
| 実行使用メモリ | 215,060 KB |
| 最終ジャッジ日時 | 2025-04-09 21:00:55 |
| 合計ジャッジ時間 | 31,167 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 20 TLE * 1 -- * 9 |
ソースコード
import sys
MOD = 998244353
def main():
sys.setrecursionlimit(1 << 25)
N, Q = map(int, sys.stdin.readline().split())
a = []
b = []
for _ in range(N):
ai, bi = map(int, sys.stdin.readline().split())
a.append(ai % MOD)
b.append(bi % MOD)
class SegmentTreeNode:
def __init__(self, l, r):
self.l = l
self.r = r
self.left = None
self.right = None
self.matrix = None
def build(l, r):
node = SegmentTreeNode(l, r)
if l == r:
node.matrix = (a[l], b[l])
else:
mid = (l + r) // 2
node.left = build(l, mid)
node.right = build(mid + 1, r)
a1, b1 = node.left.matrix
a2, b2 = node.right.matrix
a_combined = (a2 * a1) % MOD
b_combined = (a2 * b1 + b2) % MOD
node.matrix = (a_combined, b_combined)
return node
root = build(0, N - 1)
def query_segment(node, l, r):
if node.r < l or node.l > r:
return (1, 0)
if l <= node.l and node.r <= r:
return node.matrix
left_a, left_b = query_segment(node.left, l, r)
right_a, right_b = query_segment(node.right, l, r)
combined_a = (right_a * left_a) % MOD
combined_b = (right_a * left_b + right_b) % MOD
return (combined_a, combined_b)
def get_combined(l, r):
if l > r:
return (1, 0)
return query_segment(root, l, r)
def get_intervals(start_k, end_k, p):
start_i = start_k ^ p
end_i = end_k ^ p
if start_i > end_i:
start_i, end_i = end_i, start_i
intervals = []
s = start_i
e = end_i
while s <= e:
m = (s | ((1 << 30) - 1)) & (~((1 << (len(bin(s ^ e)) - 2)) - 1))
if m > e:
m = e
intervals.append((s, m))
s = m + 1
return intervals
for _ in range(Q):
l, r, p, x = map(int, sys.stdin.readline().split())
start_k = l
end_k = r - 1
if start_k > end_k:
print(x % MOD)
continue
all_intervals = []
current_start = start_k
while current_start <= end_k:
next_end = current_start | ((current_start ^ p) ^ ((current_start ^ p) | (current_start ^ p)))
next_end = min(next_end, end_k)
interval_start_i = current_start ^ p
interval_end_i = next_end ^ p
if interval_start_i > interval_end_i:
interval_start_i, interval_end_i = interval_end_i, interval_start_i
all_intervals.append((interval_start_i, interval_end_i))
current_start = next_end + 1
combined_a = 1
combined_b = 0
for (s, e) in all_intervals:
if s > e:
continue
seg_a, seg_b = get_combined(s, e)
combined_a = (combined_a * seg_a) % MOD
combined_b = (seg_a * combined_b + seg_b) % MOD
res = (combined_a * x + combined_b) % MOD
print(res)
if __name__ == '__main__':
main()
lam6er