結果
問題 | No.650 行列木クエリ |
ユーザー |
![]() |
提出日時 | 2025-04-09 21:05:10 |
言語 | PyPy3 (7.3.15) |
結果 | WA |
実行時間 | - |
コード長 | 9,000 bytes |
コンパイル時間 | 204 ms |
コンパイル使用メモリ | 82,792 KB |
実行使用メモリ | 186,172 KB |
最終ジャッジ日時 | 2025-04-09 21:06:46 |
合計ジャッジ時間 | 4,794 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 3 WA * 7 |
ソースコード
import sys
from sys import stdin
from collections import deque

MOD = 10**9 + 7

# The 2x2 identity matrix. It is a tuple so it can be shared freely: nothing in
# this program mutates a matrix in place (multiply() always builds a new one).
_IDENTITY = ((1, 0), (0, 1))


def multiply(a, b):
    """Return the 2x2 matrix product ``a * b`` with entries reduced mod MOD.

    Matrix multiplication is not commutative, so the argument order matters.
    """
    c00 = (a[0][0] * b[0][0] + a[0][1] * b[1][0]) % MOD
    c01 = (a[0][0] * b[0][1] + a[0][1] * b[1][1]) % MOD
    c10 = (a[1][0] * b[0][0] + a[1][1] * b[1][0]) % MOD
    c11 = (a[1][0] * b[0][1] + a[1][1] * b[1][1]) % MOD
    return [[c00, c01], [c10, c11]]


class _MatrixSegmentTree:
    """Segment tree of 2x2 matrices.

    Supports point assignment and the ordered (left-to-right) product over a
    half-open index range. All leaves start as the identity matrix.
    """

    def __init__(self, size):
        n = 1
        while n < size:
            n <<= 1
        self.n = n
        # Every leaf is the identity, so every internal product is too.
        self.tree = [_IDENTITY] * (2 * n)

    def update(self, index, value):
        """Set leaf ``index`` to ``value`` and recompute its ancestors."""
        tree = self.tree
        i = index + self.n
        tree[i] = value
        i >>= 1
        while i:
            tree[i] = multiply(tree[2 * i], tree[2 * i + 1])
            i >>= 1

    def query(self, lo, hi):
        """Return leaf[lo] * leaf[lo+1] * ... * leaf[hi-1] (identity if empty)."""
        # Nodes met on the left boundary are appended to ``left``. Nodes met on
        # the right boundary are prepended to ``right``. Two accumulators are
        # required because the matrix product is not commutative.
        tree = self.tree
        left = _IDENTITY
        right = _IDENTITY
        lo += self.n
        hi += self.n
        while lo < hi:
            if lo & 1:
                left = multiply(left, tree[lo])
                lo += 1
            if hi & 1:
                hi -= 1
                right = multiply(tree[hi], right)
            lo >>= 1
            hi >>= 1
        return multiply(left, right)


def _build_hld(n, edges):
    """Heavy-light decompose the tree on vertices 0..n-1 rooted at 0.

    Returns ``(parent, head, pos, edge_child)``:
      parent[v]     -- parent of v (-1 for the root)
      head[v]       -- topmost vertex of v's heavy chain
      pos[v]        -- slot of v in the segment tree. The slot stores the
                       matrix of the edge parent[v] -> v. Each chain occupies
                       a contiguous run of slots, ordered from its head down.
      edge_child[e] -- the deeper endpoint of input edge e, i.e. the vertex
                       whose slot holds edge e's matrix
    """
    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # BFS from the root: parent links plus an order where parents precede children.
    parent = [-1] * n
    visited = [False] * n
    visited[0] = True
    order = []
    queue = deque([0])
    while queue:
        u = queue.popleft()
        order.append(u)
        for v in adj[u]:
            if not visited[v]:
                visited[v] = True
                parent[v] = u
                queue.append(v)

    size = [1] * n
    for u in reversed(order):
        if parent[u] != -1:
            size[parent[u]] += size[u]

    heavy = [-1] * n
    for v in order:
        p = parent[v]
        if p != -1 and (heavy[p] == -1 or size[v] > size[heavy[p]]):
            heavy[p] = v

    # Lay chains out contiguously. Walk each chain from its head down the
    # heavy edges, and defer every light child as the head of a new chain.
    head = [0] * n
    pos = [0] * n
    next_pos = 0
    stack = [0]
    while stack:
        top = stack.pop()
        v = top
        while v != -1:
            head[v] = top
            pos[v] = next_pos
            next_pos += 1
            for w in adj[v]:
                if w != parent[v] and w != heavy[v]:
                    stack.append(w)
            v = heavy[v]

    edge_child = [b if parent[b] == a else a for a, b in edges]
    return parent, head, pos, edge_child


def solve(n, edges, queries):
    """Answer the queries of yukicoder No.650 and return the 'g' results.

    Every edge initially holds the 2x2 identity matrix. Queries:
      ('x', i, x00, x01, x10, x11) -- set edge i's matrix (entries taken mod MOD)
      ('g', i, j)                   -- product of the edge matrices on the path
                                       from i down to j, edge nearest i first.
                                       i is expected to be an ancestor of j.
    Query fields may be strings or ints; they are converted with int().

    Returns a list of 2x2 matrices (lists of lists), one per 'g' query, in order.
    Raises ValueError on an unknown query type.
    """
    parent, head, pos, edge_child = _build_hld(n, edges)
    seg = _MatrixSegmentTree(n)
    answers = []
    for query in queries:
        kind = query[0]
        if kind == 'x':
            i = int(query[1])
            x00, x01, x10, x11 = (int(t) % MOD for t in query[2:6])
            seg.update(pos[edge_child[i]], [[x00, x01], [x10, x11]])
        elif kind == 'g':
            top = int(query[1])
            cur = int(query[2])
            res = _IDENTITY
            # Climb from j toward i one chain at a time. The slot range
            # [pos[head], pos[cur]] includes the light edge parent[head] -> head.
            # The chain's product is prepended because it lies closer to i.
            while head[cur] != head[top]:
                h = head[cur]
                res = multiply(seg.query(pos[h], pos[cur] + 1), res)
                cur = parent[h]
            # Same chain: the edges into the vertices strictly below i, down to cur.
            res = multiply(seg.query(pos[top] + 1, pos[cur] + 1), res)
            answers.append(res)
        else:
            raise ValueError(f"invalid query type: {kind!r}")
    return answers


def main():
    """Read the problem input from stdin and print one line per 'g' query."""
    it = iter(stdin.read().split())
    n = int(next(it))
    edges = [(int(next(it)), int(next(it))) for _ in range(n - 1)]
    q = int(next(it))
    queries = []
    for _ in range(q):
        kind = next(it)
        # 'x' carries an edge index plus four entries; 'g' carries two vertices.
        arity = 5 if kind == 'x' else 2
        queries.append([kind] + [next(it) for _ in range(arity)])
    out = [
        '%d %d %d %d' % (m[0][0], m[0][1], m[1][0], m[1][1])
        for m in solve(n, edges, queries)
    ]
    if out:
        sys.stdout.write('\n'.join(out) + '\n')


if __name__ == '__main__':
    main()