結果

問題 No.650 行列木クエリ
ユーザー gew1fw
提出日時 2025-06-12 20:30:14
言語 PyPy3 (7.3.15)
結果 TLE
実行時間 -
コード長 4,983 bytes
コンパイル時間 363 ms
コンパイル使用メモリ 82,264 KB
実行使用メモリ 225,524 KB
最終ジャッジ日時 2025-06-12 20:30:23
合計ジャッジ時間 8,455 ms
ジャッジサーバーID (参考情報) judge3 / judge2
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 7 TLE * 1 -- * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 1_000_000_007  # prime modulus applied to every matrix entry (10**9 + 7)

def main():
    """Answer path-product queries on a tree whose edges carry 2x2 matrices.

    Reads whitespace-separated tokens from stdin:

    * ``n``, then ``n - 1`` pairs ``a b``; the k-th pair (0-indexed) is edge
      ``k``.  Vertex 0 is used as the root and every edge matrix starts as
      the identity.
    * ``q``, then ``q`` queries:

      - ``x i x00 x01 x10 x11``: set the matrix on edge ``i``;
      - anything else is read as ``g i j``: write one line
        ``r00 r01 r10 r11``, the product (mod ``MOD``) of the edge matrices
        on the path from ``i`` down to ``j``, multiplied top-to-bottom.
        ``i`` must be an ancestor of ``j``; ``i == j`` gives the identity.

    Heavy-light decomposition lays every heavy chain out contiguously in a
    segment tree of ordered matrix products.  An update therefore costs
    O(log n) multiplications and a path query O(log^2 n), instead of one
    step per edge.  All traversals are iterative, so deep (path-like) trees
    cannot overflow the interpreter stack.

    Raises:
        ValueError: if the climb from ``j`` passes the root without reaching
            ``i``'s chain (``i`` is not an ancestor of ``j``).  Other invalid
            ``(i, j)`` pairs are not validated.
    """
    mod = MOD
    data = sys.stdin.buffer.read().split()
    ptr = 0
    n = int(data[ptr])
    ptr += 1

    # Adjacency lists of (neighbour, edge id); edge ids follow input order.
    adj = [[] for _ in range(n)]
    for e in range(n - 1):
        a = int(data[ptr])
        b = int(data[ptr + 1])
        ptr += 2
        adj[a].append((b, e))
        adj[b].append((a, e))

    # BFS from root 0 records each vertex's parent, the id of the edge to
    # that parent, and a top-down order (every vertex after its parent).
    parent = [-1] * n
    parent_edge = [-1] * n
    seen = [False] * n
    seen[0] = True
    order = [0]
    k = 0
    while k < len(order):
        u = order[k]
        k += 1
        for v, e in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                parent_edge[v] = e
                order.append(v)

    # Subtree sizes, accumulated bottom-up; heavy[u] is u's child with the
    # largest subtree (-1 for leaves).
    size = [1] * n
    heavy = [-1] * n
    for u in reversed(order):
        p = parent[u]
        if p >= 0:
            size[p] += size[u]
            # size[u] is already final here: u's descendants come after u in
            # BFS order, so the reversed walk has processed them.
            if heavy[p] < 0 or size[u] > size[heavy[p]]:
                heavy[p] = u

    # Heavy-light decomposition: each heavy chain gets consecutive slots,
    # with the chain's top vertex in the smallest slot.
    head = [0] * n
    pos = [0] * n
    next_slot = 0
    stack = [0]
    while stack:
        top = stack.pop()
        v = top
        while v >= 0:
            head[v] = top
            pos[v] = next_slot
            next_slot += 1
            hv = heavy[v]
            pv = parent[v]
            for w, _ in adj[v]:
                if w != pv and w != hv:
                    stack.append(w)  # a light child starts its own chain
            v = hv

    # Slot pos[v] holds the matrix of edge (parent[v], v); the root's slot
    # stays the identity.
    edge_slot = [0] * (n - 1)
    for v in range(n):
        e = parent_edge[v]
        if e >= 0:
            edge_slot[e] = pos[v]

    identity = (1, 0, 0, 1)
    width = 1
    while width < n:
        width <<= 1
    # Node k stores the ordered product tree[2k] @ tree[2k+1].  All leaves
    # start as the identity, so every internal node does too; the shared
    # tuple is immutable, so aliasing it is safe.
    tree = [identity] * (2 * width)

    def mul(a, b):
        """Return a @ b (mod ``mod``) for row-major tuples (m00, m01, m10, m11)."""
        a00, a01, a10, a11 = a
        b00, b01, b10, b11 = b
        return ((a00 * b00 + a01 * b10) % mod,
                (a00 * b01 + a01 * b11) % mod,
                (a10 * b00 + a11 * b10) % mod,
                (a10 * b01 + a11 * b11) % mod)

    def update(slot, mat):
        """Store ``mat`` at ``slot`` and recompute the products above it."""
        node = slot + width
        tree[node] = mat
        node >>= 1
        while node:
            tree[node] = mul(tree[2 * node], tree[2 * node + 1])
            node >>= 1

    def query(lo, hi):
        """Return tree[lo] @ ... @ tree[hi] (inclusive slots), in slot order."""
        left = identity
        right = identity
        lo += width
        hi += width + 1  # switch to a half-open range
        while lo < hi:
            if lo & 1:
                left = mul(left, tree[lo])
                lo += 1
            if hi & 1:
                hi -= 1
                # Pieces found on the right end belong before what is
                # already in `right`, so they are multiplied on the left.
                right = mul(tree[hi], right)
            lo >>= 1
            hi >>= 1
        return mul(left, right)

    q = int(data[ptr])
    ptr += 1
    out = []
    for _ in range(q):
        if data[ptr] == b'x':
            e = int(data[ptr + 1])
            mat = (int(data[ptr + 2]) % mod, int(data[ptr + 3]) % mod,
                   int(data[ptr + 4]) % mod, int(data[ptr + 5]) % mod)
            ptr += 6
            update(edge_slot[e], mat)
        else:
            i = int(data[ptr + 1])
            j = int(data[ptr + 2])
            ptr += 3
            # Climb from j one chain at a time.  Each higher segment comes
            # earlier on the i -> j path, so it is multiplied on the left.
            res = identity
            target = head[i]
            while head[j] != target:
                h = head[j]
                res = mul(query(pos[h], pos[j]), res)
                j = parent[h]
                if j < 0:
                    raise ValueError(
                        f'vertex {i} is not an ancestor of the queried vertex')
            if j != i:
                # Same chain now: the edges strictly below i down to j.
                res = mul(query(pos[i] + 1, pos[j]), res)
            out.append(' '.join(map(str, res)))
    if out:
        sys.stdout.write('\n'.join(out) + '\n')

# Run the solver only when executed as a script, not when imported.
if __name__ == '__main__':
    main()