結果
| 問題 | No.1796 木上のクーロン |
| コンテスト | |
| ユーザー | qwewe |
| 提出日時 | 2025-05-14 13:19:17 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 13,105 bytes |
| コンパイル時間 | 309 ms |
| コンパイル使用メモリ | 82,796 KB |
| 実行使用メモリ | 278,700 KB |
| 最終ジャッジ日時 | 2025-05-14 13:20:32 |
| 合計ジャッジ時間 | 27,458 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 17 TLE * 1 -- * 16 |
ソースコード
import sys
# The tree helpers below may nest once per node, so request a recursion
# limit slightly above the largest expected N (200000).
_RECURSION_LIMIT = 200000 + 10
try:
    sys.setrecursionlimit(_RECURSION_LIMIT)
except Exception:
    # Best effort only: if the limit cannot be raised the program keeps the
    # interpreter default and may fail later on very deep trees.
    pass
def solve():
    """Read a charged tree from stdin and print the field value of every node.

    Input (whitespace separated): ``N``, then ``Q_1 .. Q_N``, then ``N - 1``
    edges ``u v`` with 1-based endpoints.  For every node ``p`` (in index
    order, one value per line) the function prints

        E_p = (N!)^2 * sum_i Q_i / (dist(p, i) + 1)^2   (mod 998244353)

    The sum is evaluated with centroid decomposition.  For a centroid ``c``
    the charges of its component are bucketed by distance from ``c`` and
    correlated with ``g[d] = 1 / (d + 1)^2``; pairs lying in the same branch
    of ``c`` are subtracted again using a correlation sized by that branch's
    own depth, and are accounted for when the branch is decomposed.

    Every correlation only touches as many entries of ``g`` as the depth it
    serves requires, so a centroid costs O(n log n) for a component of
    ``n`` nodes instead of O(N log N) for the whole tree.  All tree walks are
    iterative, so deep (path-like) trees do not depend on the recursion limit.
    """
    MOD = 998244353
    primitive_root = 3   # primitive root of MOD, source of the roots of unity
    naive_limit = 2048   # at most this many multiply-adds: skip the NTT

    # ---- input ---------------------------------------------------------
    tokens = sys.stdin.read().split()
    N = int(tokens[0])
    Q = [int(tok) % MOD for tok in tokens[1:N + 1]]
    adj = [[] for _ in range(N)]
    for pos in range(N + 1, 3 * N - 1, 2):
        u = int(tokens[pos]) - 1
        v = int(tokens[pos + 1]) - 1
        adj[u].append(v)
        adj[v].append(u)

    # ---- constants -----------------------------------------------------
    # k0 = (N!)^2 mod MOD, applied once to every answer at the very end.
    k0 = 1
    for i in range(2, N + 1):
        k0 = k0 * i % MOD
    k0 = k0 * k0 % MOD

    # g[d] = inverse of (d + 1)^2.  The largest index ever needed is the sum
    # of two distances, i.e. 2 * (N - 1), so d + 1 ranges over 1 .. 2N - 1,
    # which is never 0 modulo the (much larger) prime MOD.
    limit = 2 * N - 1
    inv = [0, 1] + [0] * (limit - 1)
    for i in range(2, limit + 1):
        inv[i] = (MOD - MOD // i) * inv[MOD % i] % MOD
    g = [x * x % MOD for x in inv[1:]]

    # ---- number theoretic transform -------------------------------------
    def ntt(a, invert):
        """Transform ``a`` in place (length must be a power of two) and return it."""
        n = len(a)
        # Bit-reversal permutation.
        j = 0
        for i in range(1, n):
            bit = n >> 1
            while j & bit:
                j ^= bit
                bit >>= 1
            j ^= bit
            if i < j:
                a[i], a[j] = a[j], a[i]
        # Butterfly stages of doubling block length.
        length = 2
        while length <= n:
            w_len = pow(primitive_root, (MOD - 1) // length, MOD)
            if invert:
                w_len = pow(w_len, MOD - 2, MOD)
            half = length >> 1
            # Powers of the stage root, shared by every block of this stage.
            roots = [1] * half
            for k in range(1, half):
                roots[k] = roots[k - 1] * w_len % MOD
            for start in range(0, n, length):
                for k in range(half):
                    lo = start + k
                    hi = lo + half
                    x = a[lo]
                    y = a[hi] * roots[k] % MOD
                    a[lo] = (x + y) % MOD
                    a[hi] = (x - y) % MOD
            length <<= 1
        if invert:
            n_inv = pow(n, MOD - 2, MOD)
            for i in range(n):
                a[i] = a[i] * n_inv % MOD
        return a

    g_spectrum = {}  # transform size -> forward NTT of g truncated/padded to it

    def correlate(weights, max_shift):
        """Return ``[sum_j weights[j] * g[t + j] for t in 0 .. max_shift]``."""
        top = len(weights) - 1
        if (top + 1) * (max_shift + 1) <= naive_limit:
            out = [0] * (max_shift + 1)
            for t in range(max_shift + 1):
                acc = 0
                for j, w in enumerate(weights):
                    acc += w * g[t + j]
                out[t] = acc % MOD
            return out
        # Multiply the reversed weights with a prefix of g; the wanted values
        # sit at product indices top .. top + max_shift.  A cyclic transform
        # of size >= top + max_shift + 1 is enough: every linear index that
        # wraps around lands below ``top`` and never touches that window.
        size = 1
        while size < top + max_shift + 1:
            size <<= 1
        spec_g = g_spectrum.get(size)
        if spec_g is None:
            padded = g[:size]
            padded += [0] * (size - len(padded))
            spec_g = g_spectrum[size] = ntt(padded, False)
        spec_w = ntt(weights[::-1] + [0] * (size - top - 1), False)
        product = ntt([x * y % MOD for x, y in zip(spec_w, spec_g)], True)
        return product[top:top + max_shift + 1]

    # ---- centroid decomposition (iterative) ------------------------------
    ans = [0] * N             # accumulated sums, still without the k0 factor
    removed = [False] * N     # True once a node has been used as a centroid
    parent = [-1] * N         # scratch: BFS parent inside the current walk
    size = [0] * N            # scratch: subtree sizes inside the current walk
    dist = [0] * N            # scratch: distance from the current centroid

    pending = [0]             # one representative node per unprocessed component
    while pending:
        root = pending.pop()

        # Walk the component once to obtain a parent-before-child ordering.
        parent[root] = -1
        order = [root]
        head = 0
        while head < len(order):
            u = order[head]
            head += 1
            size[u] = 1
            pu = parent[u]
            for v in adj[u]:
                if v != pu and not removed[v]:
                    parent[v] = u
                    order.append(v)
        total = len(order)
        for u in reversed(order):
            pu = parent[u]
            if pu >= 0:
                size[pu] += size[u]

        # Descend towards any child holding more than half of the component.
        c = root
        while True:
            pc = parent[c]
            for v in adj[c]:
                if v != pc and not removed[v] and size[v] * 2 > total:
                    c = v
                    break
            else:
                break
        removed[c] = True

        # Breadth-first walk of every branch hanging off the centroid.
        branches = []
        max_depth = 0
        for start in adj[c]:
            if removed[start]:
                continue
            parent[start] = c
            dist[start] = 1
            nodes = [start]
            head = 0
            while head < len(nodes):
                u = nodes[head]
                head += 1
                pu = parent[u]
                next_dist = dist[u] + 1
                for v in adj[u]:
                    if v != pu and not removed[v]:
                        parent[v] = u
                        dist[v] = next_dist
                        nodes.append(v)
            branches.append(nodes)
            # BFS order is by non-decreasing distance: the last node is deepest.
            if dist[nodes[-1]] > max_depth:
                max_depth = dist[nodes[-1]]

        # Charges of the whole component bucketed by distance from c.
        level_sum = [0] * (max_depth + 1)
        level_sum[0] = Q[c]
        for nodes in branches:
            for u in nodes:
                d = dist[u]
                level_sum[d] = (level_sum[d] + Q[u]) % MOD
        field_all = correlate(level_sum, max_depth)
        ans[c] = (ans[c] + field_all[0]) % MOD

        for nodes in branches:
            depth = dist[nodes[-1]]
            branch_sum = [0] * (depth + 1)
            for u in nodes:
                d = dist[u]
                branch_sum[d] = (branch_sum[d] + Q[u]) % MOD
            field_own = correlate(branch_sum, depth)
            # Paths through c to everything outside the node's own branch.
            for u in nodes:
                d = dist[u]
                ans[u] = (ans[u] + field_all[d] - field_own[d]) % MOD
            pending.append(nodes[0])

    # ---- output ----------------------------------------------------------
    print("\n".join(str(value * k0 % MOD) for value in ans))
# Run the solver only when this file is executed as a script, so that
# importing the module does not start reading from standard input.
if __name__ == "__main__":
    solve()
qwewe