結果
問題 | No.1796 木上のクーロン |
ユーザー |
![]() |
提出日時 | 2025-05-14 13:19:17 |
言語 | PyPy3 (7.3.15) |
結果 | TLE |
実行時間 | - |
コード長 | 13,105 bytes |
コンパイル時間 | 309 ms |
コンパイル使用メモリ | 82,796 KB |
実行使用メモリ | 278,700 KB |
最終ジャッジ日時 | 2025-05-14 13:20:32 |
合計ジャッジ時間 | 27,458 ms |
ジャッジサーバーID (参考情報) | judge3 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 17 TLE * 1 -- * 16 |
ソースコード
import sys

# Prime modulus used for all arithmetic; 3 is a primitive root of it, which is
# what makes the number-theoretic transform below possible.
MOD = 998244353
PRIMITIVE_ROOT = 3

# Below this operand length a schoolbook product is cheaper than an NTT.
_NAIVE_LIMIT = 16


def _ntt(a, invert):
    """Transform list *a* in place with a radix-2 NTT modulo ``MOD``.

    ``len(a)`` must be a power of two.  With ``invert`` true the inverse
    transform (including the 1/len scaling) is applied.
    """
    n = len(a)

    # Bit-reversal permutation.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    length = 2
    while length <= n:
        root = pow(PRIMITIVE_ROOT, (MOD - 1) // length, MOD)
        if invert:
            root = pow(root, MOD - 2, MOD)
        half = length >> 1
        # Powers of the root are shared by every block of this stage, so
        # build them once instead of once per block.
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * root % MOD
        for block in range(0, n, length):
            for k in range(half):
                lo = block + k
                hi = lo + half
                x = a[lo]
                y = a[hi] * twiddles[k] % MOD
                a[lo] = (x + y) % MOD
                a[hi] = (x - y) % MOD
        length <<= 1

    if invert:
        n_inv = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * n_inv % MOD


def _convolve(p, q):
    """Return the coefficient list of the product p(x) * q(x) modulo ``MOD``."""
    if not p or not q:
        return []
    out_len = len(p) + len(q) - 1

    if min(len(p), len(q)) <= _NAIVE_LIMIT:
        out = [0] * out_len
        for i, x in enumerate(p):
            if x:
                for k, y in enumerate(q):
                    out[i + k] = (out[i + k] + x * y) % MOD
        return out

    size = 1
    while size < out_len:
        size <<= 1
    fa = p + [0] * (size - len(p))
    fb = q + [0] * (size - len(q))
    _ntt(fa, False)
    _ntt(fb, False)
    fc = [x * y % MOD for x, y in zip(fa, fb)]
    _ntt(fc, True)
    return fc[:out_len]


def _potential_by_depth(hist, g):
    """Return ``f`` with ``f[d] = sum_j hist[j] * g[d + j]`` for every depth d.

    ``hist[j]`` is the total charge at distance ``j``; ``g`` must have at
    least ``2 * len(hist) - 1`` entries.  Only that prefix of ``g`` is passed
    to the convolution, so the cost depends on the component size and not on
    the size of the whole tree.
    """
    top = len(hist) - 1
    reversed_hist = [x % MOD for x in reversed(hist)]
    product = _convolve(reversed_hist, g[: 2 * top + 1])
    # Coefficient (d + top) of the product is
    # sum_k hist[top - k] * g[d + top - k] = sum_j hist[j] * g[d + j].
    return product[top : 2 * top + 1]


def tree_coulomb(n: int, charges: list[int], edges: list[tuple[int, int]]) -> list[int]:
    """Compute the answer for every vertex of the tree.

    For each vertex p the result is::

        (n!)^2 * sum_i charges[i] / (dist(p, i) + 1)^2   (mod 998244353)

    where the sum includes ``i == p``.

    Args:
        n: Number of vertices.
        charges: ``charges[i]`` is the charge on vertex ``i + 1``.
        edges: The ``n - 1`` tree edges as pairs of 1-based vertex numbers.

    Returns:
        A list of ``n`` residues, one per vertex in input order.
    """
    if n <= 0:
        return []

    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u - 1].append(v - 1)
        adj[v - 1].append(u - 1)

    factorial = 1
    for i in range(2, n + 1):
        factorial = factorial * i % MOD
    k0 = factorial * factorial % MOD

    # g[d] = 1 / (d + 1)^2 for d = 0 .. 2n - 2, via the linear recurrence for
    # modular inverses (valid because every index is smaller than MOD).
    limit = 2 * n
    inv = [0, 1] + [0] * (limit - 2)
    for i in range(2, limit):
        inv[i] = (MOD - MOD // i) * inv[MOD % i] % MOD
    g = [x * x % MOD for x in inv[1:]]

    ans = [0] * n
    removed = [False] * n  # vertices already used as a centroid
    parent = [-1] * n
    size = [0] * n
    depth = [0] * n

    # Explicit stack of component representatives: no recursion, so a path
    # of 200000 vertices cannot overflow the interpreter stack.
    pending = [0]
    while pending:
        start = pending.pop()

        # BFS order of the component; parents always precede children.
        parent[start] = -1
        order = [start]
        head = 0
        while head < len(order):
            u = order[head]
            head += 1
            size[u] = 1
            pu = parent[u]
            for v in adj[u]:
                if v != pu and not removed[v]:
                    parent[v] = u
                    order.append(v)

        # Walk the order backwards accumulating subtree sizes.  The first
        # vertex whose subtree holds at least half of the component is the
        # deepest such vertex, hence a centroid.  The start vertex always
        # qualifies, so parent[-1] is never touched.
        total = len(order)
        centroid = start
        for u in reversed(order):
            if 2 * size[u] >= total:
                centroid = u
                break
            size[parent[u]] += size[u]

        removed[centroid] = True

        # hist[d]: total charge at distance d from the centroid.
        hist = [charges[centroid]]
        branches = []
        for root in adj[centroid]:
            if removed[root]:
                continue
            pending.append(root)
            parent[root] = centroid
            depth[root] = 1
            nodes = [root]
            head = 0
            while head < len(nodes):
                u = nodes[head]
                head += 1
                pu = parent[u]
                next_depth = depth[u] + 1
                for v in adj[u]:
                    if v != pu and not removed[v]:
                        parent[v] = u
                        depth[v] = next_depth
                        nodes.append(v)

            # BFS order is sorted by depth, so the last vertex is deepest.
            # Sizing by the branch's own depth keeps the work proportional
            # to the branch, even when the centroid has many neighbours.
            branch_depth = depth[nodes[-1]]
            branch_hist = [0] * (branch_depth + 1)
            for u in nodes:
                branch_hist[depth[u]] += charges[u]
            if branch_depth >= len(hist):
                hist.extend([0] * (branch_depth + 1 - len(hist)))
            for d in range(1, branch_depth + 1):
                hist[d] += branch_hist[d]
            branches.append((nodes, branch_hist))

        whole = _potential_by_depth(hist, g)
        ans[centroid] = (ans[centroid] + k0 * whole[0]) % MOD

        # Pairs inside the same branch do not pass through the centroid;
        # remove them here, they are counted deeper in the decomposition.
        for nodes, branch_hist in branches:
            own = _potential_by_depth(branch_hist, g)
            for u in nodes:
                d = depth[u]
                ans[u] = (ans[u] + k0 * (whole[d] - own[d])) % MOD

    return ans


def solve():
    """Read the tree from standard input and print one answer per line.

    Input format: ``N``, then ``N`` charges, then ``N - 1`` edges given as
    pairs of 1-based vertex numbers.
    """
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    charges = [int(tok) for tok in tokens[1 : 1 + n]]
    flat = [int(tok) for tok in tokens[1 + n : 1 + n + 2 * (n - 1)]]
    edges = list(zip(flat[0::2], flat[1::2]))
    results = tree_coulomb(n, charges, edges)
    sys.stdout.write("\n".join(map(str, results)) + "\n")


if __name__ == "__main__":
    solve()