結果
| 問題 |
No.650 行列木クエリ
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-04-09 21:05:10 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 9,000 bytes |
| コンパイル時間 | 204 ms |
| コンパイル使用メモリ | 82,792 KB |
| 実行使用メモリ | 186,172 KB |
| 最終ジャッジ日時 | 2025-04-09 21:06:46 |
| 合計ジャッジ時間 | 4,794 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 3 WA * 7 |
ソースコード
import sys
from sys import stdin
from collections import deque
MOD = 10**9 + 7
def main():
sys.setrecursionlimit(1 << 25)
n = int(stdin.readline())
edges_input = []
adj = [[] for _ in range(n)]
for i in range(n-1):
a, b = map(int, stdin.readline().split())
edges_input.append((a, b))
adj[a].append(b)
adj[b].append(a)
parent = [ -1 ] * n
size = [1] * n
depth = [0] * n
# BFS from root 0 to compute parent, size, depth
q = deque([0])
parent[0] = -1
while q:
u = q.popleft()
for v in adj[u]:
if parent[v] == -1 and v != parent[u]:
parent[v] = u
depth[v] = depth[u] + 1
q.append(v)
# Compute subtree sizes
order = list(range(n))
order.sort(key=lambda x: -depth[x])
for u in order:
if parent[u] != -1:
size[parent[u]] += size[u]
# Determine edge_map and edge_for_node
edge_map = [None] * (n-1)
edge_for_node = [ -1 ] * n # edge_for_node[v] gives input edge index for parent[v] to v
for i in range(n-1):
a, b = edges_input[i]
if parent[b] == a:
u, v = a, b
elif parent[a] == b:
u, v = b, a
else:
assert False, "Invalid tree edge"
edge_map[i] = (u, v)
# Compute edge_for_node for each node except root
edge_for_node = [ -1 ] * n
for i in range(n-1):
u, v = edge_map[i]
if v != 0:
edge_for_node[v] = i
# HLD
head = [0] * n
pos = [0] * n # position in chain's edge list
chains = [] # each chain is a list of nodes in the chain's edge list
# Compute heavy child for each node
heavy = [ -1 ] * n
for u in order:
max_size = -1
for v in adj[u]:
if v == parent[u]:
continue
if size[v] > max_size:
max_size = size[v]
heavy[u] = v
# Perform HLD
current_pos = 0
for u in range(n):
if parent[u] == -1 or u != heavy[parent[u]]:
chain = []
v = u
while True:
head[v] = u
chain.append(v)
if v != u:
pos[v] = len(chain) -1
else:
# head's position is not part of edge list
pos[v] = -1 # or any invalid value, since head's edge is not present
if heavy[v] == -1:
break
v = heavy[v]
chains.append(chain)
# Now, build chains' edge lists and adjust positions
# Note: the chains list stores the nodes in each chain, with the head as the first element
# For each chain, the edge list consists of the nodes except the head (their edges are from parent)
# For example, chain [0, 2, 4], the edges are 0→2 (stored at 2's edge), 2→4 (stored at4's edge)
# Recompute head and pos based on chains built
# And build for each chain, the list of edge nodes (v)
chain_edges = []
chain_dict = {}
head = [0] * n
pos = [0] * n
for chain in chains:
head_node = chain[0]
edge_list = chain[1:] # the edges are for v in edge_list: parent[v] ->v
chain_edges.append(edge_list)
# Update head for all nodes in this chain
for v in chain:
head[v] = head_node
# update pos for nodes in edge_list
for idx, v in enumerate(edge_list):
pos[v] = idx
chain_dict[head_node] = (edge_list, len(edge_list))
# Build segment trees for each chain
class SegmentTree:
def __init__(self, size, initial_values, identity):
self.n = 1
while self.n < size:
self.n <<=1
self.size = size
self.tree = [identity] * (2 * self.n)
self.identity = identity
# Fill leaves
for i in range(size):
self.tree[self.n +i] = initial_values[i]
# Fill parents
for i in range(self.n-1, 0, -1):
self.tree[i] = multiply(self.tree[2*i], self.tree[2*i +1])
def update(self, pos, value):
pos += self.n
self.tree[pos] = value
pos >>=1
while pos >=1:
new_val = multiply(self.tree[2*pos], self.tree[2*pos +1])
if self.tree[pos] == new_val:
break
self.tree[pos] = new_val
pos >>=1
def query_range(self, l, r):
# query [l, r) interval
res = self.identity
l += self.n
r += self.n
while l < r:
if l %2 ==1:
res = multiply(res, self.tree[l])
l +=1
if r %2 ==1:
r -=1
res = multiply(res, self.tree[r])
l >>=1
r >>=1
return res
# Prepare for each chain: create segment tree
# Each chain's edge list is chain_edges[i], which is a list of nodes v, each representing parent[v]->v edge
# edge_for_node[v] gives the input edge index for their edge
# matrices are stored in edges_matrices (size n-1)
edges_matrices = [ [[1,0], [0,1]] for _ in range(n-1)] # initial identity matrix
chain_segment_trees = []
chain_info = {}
for i in range(len(chains)):
chain_head = chains[i][0]
edge_list = chain_edges[i] # list of nodes v in the edge list, which correspond to edges parent[v]→v
size_chain = len(edge_list)
initial = []
for v in edge_list:
edge_idx = edge_for_node[v]
mat = edges_matrices[edge_idx]
initial.append(mat)
# The identity matrix for the segment tree
identity = [[1,0], [0,1]]
st = SegmentTree(size_chain, initial, identity)
chain_info[chain_head] = (edge_list, st)
def update_edge(i, new_mat):
# get the child node v of edge i
u, v = edge_map[i]
# get chain info for v
head_v = head[v]
# check if v is in the edge list of this chain
chain_head = head_v
edge_list, st = chain_info[chain_head]
# find the index of v in edge_list
idx = pos[v]
# Update the segment tree
st.update(idx, new_mat)
def query_path(i, j):
# i is ancestor of j
res = [[1,0], [0,1]] # identity matrix
current = j
while True:
if head[current] != head[i]:
# Process the entire chain of current up to head
chain_head = head[current]
edge_list, st = chain_info[chain_head]
l = 0
r = pos[current] +1 # [0, pos[current]]
# Edge_list's indices from 0 to pos[current]
product = st.query_range(l, r)
# Multiply to the left of res
res = multiply(product, res)
current = parent[chain_head]
else:
# Process from i to current
if i == head[i]:
# current is in the same chain, i is the head
l = 0
if current == i:
r = 0
else:
r = pos[current] +1
else:
l = pos[i] +1
r = pos[current] +1
# Check if l <= r-1 (since r is exclusive)
if l < r:
chain_head = head[i]
edge_list, st = chain_info[chain_head]
product = st.query_range(l, r)
res = multiply(product, res)
break
return res
q = int(stdin.readline())
for _ in range(q):
parts = stdin.readline().split()
if not parts:
continue
if parts[0] == 'x':
_, i, x00, x01, x10, x11 = parts
i = int(i)
x00 = int(x00)
x01 = int(x01)
x10 = int(x10)
x11 = int(x11)
new_mat = [ [x00 % MOD, x01 % MOD], [x10 % MOD, x11 % MOD] ]
edges_matrices[i] = new_mat # Update the matrix
update_edge(i, new_mat)
elif parts[0] == 'g':
_, i, j = parts
i = int(i)
j = int(j)
mat = query_path(i, j)
# Flatten the matrix
print(mat[0][0], mat[0][1], mat[1][0], mat[1][1])
else:
assert False, "Invalid query"
def multiply(a, b):
# Compute matrix multiplication a * b
# a is m1, b is m2
c00 = (a[0][0] * b[0][0] + a[0][1] * b[1][0]) % MOD
c01 = (a[0][0] * b[0][1] + a[0][1] * b[1][1]) % MOD
c10 = (a[1][0] * b[0][0] + a[1][1] * b[1][0]) % MOD
c11 = (a[1][0] * b[0][1] + a[1][1] * b[1][1]) % MOD
return [ [c00, c01], [c10, c11] ]
if __name__ == '__main__':
main()
lam6er