結果

問題 No.1891 Static Xor Range Composite Query
ユーザー lam6er
提出日時 2025-04-09 20:58:10
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,193 bytes
コンパイル時間 437 ms
コンパイル使用メモリ 82,352 KB
実行使用メモリ 215,060 KB
最終ジャッジ日時 2025-04-09 21:00:55
合計ジャッジ時間 31,167 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 20 TLE * 1 -- * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 998244353

def main():
    sys.setrecursionlimit(1 << 25)
    N, Q = map(int, sys.stdin.readline().split())
    a = []
    b = []
    for _ in range(N):
        ai, bi = map(int, sys.stdin.readline().split())
        a.append(ai % MOD)
        b.append(bi % MOD)
    
    class SegmentTreeNode:
        def __init__(self, l, r):
            self.l = l
            self.r = r
            self.left = None
            self.right = None
            self.matrix = None
    
    def build(l, r):
        node = SegmentTreeNode(l, r)
        if l == r:
            node.matrix = (a[l], b[l])
        else:
            mid = (l + r) // 2
            node.left = build(l, mid)
            node.right = build(mid + 1, r)
            a1, b1 = node.left.matrix
            a2, b2 = node.right.matrix
            a_combined = (a2 * a1) % MOD
            b_combined = (a2 * b1 + b2) % MOD
            node.matrix = (a_combined, b_combined)
        return node
    
    root = build(0, N - 1)
    
    def query_segment(node, l, r):
        if node.r < l or node.l > r:
            return (1, 0)
        if l <= node.l and node.r <= r:
            return node.matrix
        left_a, left_b = query_segment(node.left, l, r)
        right_a, right_b = query_segment(node.right, l, r)
        combined_a = (right_a * left_a) % MOD
        combined_b = (right_a * left_b + right_b) % MOD
        return (combined_a, combined_b)
    
    def get_combined(l, r):
        if l > r:
            return (1, 0)
        return query_segment(root, l, r)
    
    def get_intervals(start_k, end_k, p):
        start_i = start_k ^ p
        end_i = end_k ^ p
        if start_i > end_i:
            start_i, end_i = end_i, start_i
        intervals = []
        s = start_i
        e = end_i
        while s <= e:
            m = (s | ((1 << 30) - 1)) & (~((1 << (len(bin(s ^ e)) - 2)) - 1))
            if m > e:
                m = e
            intervals.append((s, m))
            s = m + 1
        return intervals
    
    for _ in range(Q):
        l, r, p, x = map(int, sys.stdin.readline().split())
        start_k = l
        end_k = r - 1
        if start_k > end_k:
            print(x % MOD)
            continue
        all_intervals = []
        current_start = start_k
        while current_start <= end_k:
            next_end = current_start | ((current_start ^ p) ^ ((current_start ^ p) | (current_start ^ p)))
            next_end = min(next_end, end_k)
            interval_start_i = current_start ^ p
            interval_end_i = next_end ^ p
            if interval_start_i > interval_end_i:
                interval_start_i, interval_end_i = interval_end_i, interval_start_i
            all_intervals.append((interval_start_i, interval_end_i))
            current_start = next_end + 1
        combined_a = 1
        combined_b = 0
        for (s, e) in all_intervals:
            if s > e:
                continue
            seg_a, seg_b = get_combined(s, e)
            combined_a = (combined_a * seg_a) % MOD
            combined_b = (seg_a * combined_b + seg_b) % MOD
        res = (combined_a * x + combined_b) % MOD
        print(res)

if __name__ == '__main__':
    main()
0