結果

問題 No.650 行列木クエリ
ユーザー lam6er
提出日時 2025-04-09 21:05:10
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 9,000 bytes
コンパイル時間 204 ms
コンパイル使用メモリ 82,792 KB
実行使用メモリ 186,172 KB
最終ジャッジ日時 2025-04-09 21:06:46
合計ジャッジ時間 4,794 ms
ジャッジサーバーID (参考情報) judge2 / judge4
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 3 WA * 7
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
from collections import deque

MOD = 10**9 + 7


class _MatrixSegmentTree:
    """Segment tree of 2x2 matrices (entries mod MOD) with ordered range products.

    Leaf k holds one matrix and ``product(lo, hi)`` returns
    ``M[lo] @ M[lo + 1] @ ... @ M[hi - 1]``.  Matrix multiplication is not
    commutative, so the query keeps separate left and right accumulators
    instead of folding every visited node into one running result.
    """

    def __init__(self, values):
        size = 1
        while size < len(values):
            size <<= 1
        self._size = size
        # Padding leaves share one identity object; that is safe because no
        # matrix is ever mutated in place (``multiply`` returns fresh lists).
        tree = [[[1, 0], [0, 1]]] * (2 * size)
        tree[size:size + len(values)] = values
        for k in range(size - 1, 0, -1):
            tree[k] = multiply(tree[2 * k], tree[2 * k + 1])
        self._tree = tree

    def assign(self, index, matrix):
        """Replace leaf ``index`` with ``matrix`` and refresh its ancestors."""
        tree = self._tree
        k = index + self._size
        tree[k] = matrix
        k >>= 1
        while k:
            tree[k] = multiply(tree[2 * k], tree[2 * k + 1])
            k >>= 1

    def product(self, lo, hi):
        """Return the ordered product of leaves ``lo .. hi-1`` (identity if empty)."""
        tree = self._tree
        left = [[1, 0], [0, 1]]
        right = [[1, 0], [0, 1]]
        lo += self._size
        hi += self._size
        while lo < hi:
            if lo & 1:
                left = multiply(left, tree[lo])
                lo += 1
            if hi & 1:
                hi -= 1
                # Right-boundary nodes come *after* everything collected so
                # far on the right, so they are prepended to ``right``.
                right = multiply(tree[hi], right)
            lo >>= 1
            hi >>= 1
        return multiply(left, right)


def _heavy_light(n, adj, root=0):
    """Heavy-light decompose the tree described by adjacency lists ``adj``.

    Returns ``(parent, head, pos)``:
      * ``parent[v]`` is v's parent when rooted at ``root`` (-1 for the root);
      * ``head[v]`` is the topmost vertex of v's heavy chain;
      * ``pos`` numbers the vertices so every heavy chain is one contiguous
        run, increasing from its head downward.
    """
    parent = [-1] * n
    seen = [False] * n
    seen[root] = True
    order = []  # BFS order: every vertex appears after its parent
    queue = deque([root])
    while queue:
        u = queue.popleft()
        order.append(u)
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                queue.append(v)

    size = [1] * n
    for u in reversed(order):
        if parent[u] != -1:
            size[parent[u]] += size[u]

    heavy = [-1] * n  # child with the largest subtree, -1 for leaves
    for u in order:
        best = 0
        for v in adj[u]:
            if v != parent[u] and size[v] > best:
                best = size[v]
                heavy[u] = v

    head = [0] * n
    pos = [0] * n
    nxt = 0
    for u in order:
        p = parent[u]
        if p != -1 and heavy[p] == u:
            continue  # u is laid out when its chain's head is processed
        v = u
        while v != -1:
            head[v] = u
            pos[v] = nxt
            nxt += 1
            v = heavy[v]
    return parent, head, pos


def _path_product(seg, parent, head, pos, top, bottom):
    """Return the ordered matrix product along the tree path ``top -> bottom``.

    ``top`` must be an ancestor of (or equal to) ``bottom``.  The matrix of
    edge ``parent[v] -> v`` is stored at ``pos[v]``.  The result multiplies
    those matrices for every v strictly below ``top`` on the path, from the
    edge nearest ``top`` to the edge nearest ``bottom``.

    Raises:
        ValueError: if the walk passes the root without reaching ``top``'s chain.
    """
    res = [[1, 0], [0, 1]]
    cur = bottom
    while head[cur] != head[top]:
        h = head[cur]
        # [pos[h], pos[cur]] covers this chain's heavy edges AND the light
        # edge parent[h] -> h, which is stored at h itself.
        res = multiply(seg.product(pos[h], pos[cur] + 1), res)
        cur = parent[h]
        if cur == -1:
            raise ValueError("top must be an ancestor of bottom")
    # Same chain now: skip top's own (parent -> top) edge, which is off the path.
    return multiply(seg.product(pos[top] + 1, pos[cur] + 1), res)


def main():
    """Solve yukicoder No.650: read queries from stdin and write answers to stdout.

    Input (whitespace separated):
      * ``n``;
      * ``n - 1`` edges ``a b`` (edge index = input order, starting at 0);
      * ``q``, then ``q`` queries:
          - ``x i a b c d``: set edge i's matrix to [[a, b], [c, d]] (mod MOD);
          - ``g i j``: print the product of the edge matrices on the path from
            i down to j (i is an ancestor of j), ordered from the edge nearest
            i to the edge nearest j, as ``m00 m01 m10 m11``.

    The tree is rooted at vertex 0, and every edge starts as the identity.

    Raises:
        ValueError: on an unknown query type, or an edge list that is not a
            tree reachable from vertex 0.
    """
    # Read through sys.stdin at call time (not a name bound at import).
    it = iter(sys.stdin.read().split())
    n = int(next(it))
    adj = [[] for _ in range(n)]
    edges = []
    for _ in range(n - 1):
        a = int(next(it))
        b = int(next(it))
        edges.append((a, b))
        adj[a].append(b)
        adj[b].append(a)

    parent, head, pos = _heavy_light(n, adj)

    # Each edge's matrix lives at the position of its lower (child) endpoint.
    # This includes light edges, which the per-chain layout used to drop.
    child_of_edge = []
    for a, b in edges:
        if parent[b] == a:
            child_of_edge.append(b)
        elif parent[a] == b:
            child_of_edge.append(a)
        else:
            raise ValueError("Invalid tree edge")

    seg = _MatrixSegmentTree([[[1, 0], [0, 1]] for _ in range(n)])

    out = []
    for _ in range(int(next(it))):
        kind = next(it)
        if kind == 'x':
            e = int(next(it))
            x00, x01, x10, x11 = [int(next(it)) % MOD for _ in range(4)]
            seg.assign(pos[child_of_edge[e]], [[x00, x01], [x10, x11]])
        elif kind == 'g':
            i = int(next(it))
            j = int(next(it))
            m = _path_product(seg, parent, head, pos, i, j)
            out.append(f"{m[0][0]} {m[0][1]} {m[1][0]} {m[1][1]}")
        else:
            raise ValueError("Invalid query")
    if out:
        sys.stdout.write("\n".join(out) + "\n")


def multiply(a, b):
    """Return the 2x2 matrix product ``a @ b`` with entries reduced mod MOD.

    Both arguments are indexed as ``x[row][col]``.  The result is a new nested
    list, so callers may share input matrices freely.
    """
    c00 = (a[0][0] * b[0][0] + a[0][1] * b[1][0]) % MOD
    c01 = (a[0][0] * b[0][1] + a[0][1] * b[1][1]) % MOD
    c10 = (a[1][0] * b[0][0] + a[1][1] * b[1][0]) % MOD
    c11 = (a[1][0] * b[0][1] + a[1][1] * b[1][1]) % MOD
    return [[c00, c01], [c10, c11]]

# Script entry point; the stray no-op `0` that followed (page residue) is removed.
if __name__ == '__main__':
    main()