結果

問題 No.1796 木上のクーロン
ユーザー qwewe
提出日時 2025-05-14 13:19:17
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 13,105 bytes
コンパイル時間 309 ms
コンパイル使用メモリ 82,796 KB
実行使用メモリ 278,700 KB
最終ジャッジ日時 2025-05-14 13:20:32
合計ジャッジ時間 27,458 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 17 TLE * 1 -- * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# A path-shaped tree makes any recursive traversal nest once per vertex, so
# request a ceiling just above the largest supported vertex count (200000).
try:
    sys.setrecursionlimit(200_010)
except Exception:
    # Best effort only: if the interpreter or platform refuses the new
    # ceiling, keep running with whatever limit is already in force.
    pass

_MOD = 998244353        # NTT-friendly prime: _MOD - 1 = 2**23 * 119
_PRIMITIVE_ROOT = 3     # primitive root of _MOD
_NAIVE_LIMIT = 8        # at or below this operand length, schoolbook multiply is used


def _ntt(a, inv):
    """Transform list *a* in place with a radix-2 NTT modulo ``_MOD``.

    ``len(a)`` must be a power of two.  When *inv* is true the inverse
    transform (including the 1/n scaling) is applied.  Returns *a*.
    """
    n = len(a)

    # Bit-reversal permutation so the butterflies can run in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    length = 2
    while length <= n:
        half = length >> 1
        w_len = pow(_PRIMITIVE_ROOT, (_MOD - 1) // length, _MOD)
        if inv:
            w_len = pow(w_len, _MOD - 2, _MOD)
        # The twiddle factors of a stage are identical for every block, so
        # build them once per stage instead of once per block.
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * w_len % _MOD
        for start in range(0, n, length):
            for k in range(half):
                lo = start + k
                hi = lo + half
                u = a[lo]
                v = a[hi] * twiddles[k] % _MOD
                a[lo] = (u + v) % _MOD
                a[hi] = (u - v) % _MOD  # Python's % keeps this non-negative
        length <<= 1

    if inv:
        n_inv = pow(n, _MOD - 2, _MOD)
        for i in range(n):
            a[i] = a[i] * n_inv % _MOD
    return a


def _convolve(p1, p2):
    """Return the coefficient list of ``p1 * p2`` modulo ``_MOD``.

    Both arguments are lists of integers and are left unmodified.  The result
    has ``len(p1) + len(p2) - 1`` entries (empty if either input is empty).
    """
    s1 = len(p1)
    s2 = len(p2)
    if not s1 or not s2:
        return []
    out_len = s1 + s2 - 1

    # Most centroid components are tiny; a direct double loop avoids the
    # fixed cost of three transforms for them.
    if min(s1, s2) <= _NAIVE_LIMIT:
        out = [0] * out_len
        for i, x in enumerate(p1):
            if x:
                for k, y in enumerate(p2):
                    out[i + k] = (out[i + k] + x * y) % _MOD
        return out

    n = 1
    while n < out_len:
        n <<= 1
    fa = p1 + [0] * (n - s1)
    fb = p2 + [0] * (n - s2)
    _ntt(fa, False)
    _ntt(fb, False)
    for i in range(n):
        fa[i] = fa[i] * fb[i] % _MOD
    _ntt(fa, True)
    del fa[out_len:]
    return fa


def _inverse_squares(count):
    """Return ``[1 / (k + 1)**2 mod _MOD for k in range(count)]``.

    Uses the linear recurrence ``inv[i] = -(p // i) * inv[p % i]`` for the
    modular inverses, which is valid because ``count`` is far below ``_MOD``.
    """
    if count <= 0:
        return []
    inv = [0] * (count + 1)
    inv[1] = 1
    for i in range(2, count + 1):
        inv[i] = (_MOD - _MOD // i) * inv[_MOD % i] % _MOD
    return [inv[k] * inv[k] % _MOD for k in range(1, count + 1)]


def solve():
    """Read a weighted tree from stdin and print one value per vertex.

    Input: ``N``, then ``N`` integers ``Q``, then ``N - 1`` edges given as
    1-based vertex pairs.  For every vertex ``p`` the program prints

        (N!)^2 * sum_i Q[i] / (dist(p, i) + 1)^2   (mod 998244353)

    on its own line, in vertex order.

    The sum is evaluated with an iterative centroid decomposition.  At each
    centroid ``c`` the charges are bucketed by distance from ``c`` and
    convolved with the inverse-square table; the pairs lying inside the same
    child subtree are subtracted again because they are counted deeper in the
    decomposition.  Every convolution is sized by the depth of the data it
    actually covers (whole component or single child subtree), which keeps
    the total work near O(N log^2 N) instead of paying for the full
    ``2N - 1`` table at every centroid.
    """
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    if n <= 0:
        return
    charge = [int(t) % _MOD for t in tokens[1:n + 1]]

    adj = [[] for _ in range(n)]
    pos = n + 1
    for _ in range(n - 1):
        u = int(tokens[pos]) - 1
        v = int(tokens[pos + 1]) - 1
        pos += 2
        adj[u].append(v)
        adj[v].append(u)

    # k0 = (N!)^2; it is a common factor, so it is applied once at the end.
    fact_n = 1
    for i in range(2, n + 1):
        fact_n = fact_n * i % _MOD
    k0 = fact_n * fact_n % _MOD

    # g[k] = 1 / (k + 1)^2.  The largest index ever read is twice the largest
    # distance from a centroid, i.e. at most 2 * (N - 1).
    g = _inverse_squares(2 * n - 1)

    ans = [0] * n            # accumulated sums, without the k0 factor
    removed = [False] * n    # True once a vertex has served as a centroid
    parent = [-1] * n        # scratch: BFS parent inside the current walk
    size = [0] * n           # scratch: subtree sizes for centroid search
    depth = [0] * n          # scratch: distance from the current centroid

    pending = [0]            # one representative vertex per unprocessed component
    while pending:
        root = pending.pop()

        # Collect the component in BFS order (parents precede children).
        order = [root]
        parent[root] = -1
        head = 0
        while head < len(order):
            u = order[head]
            head += 1
            pu = parent[u]
            for v in adj[u]:
                if v != pu and not removed[v]:
                    parent[v] = u
                    order.append(v)
        total = len(order)

        # Subtree sizes: walking the BFS order backwards visits children
        # before their parents.
        for u in order:
            size[u] = 1
        for u in reversed(order):
            pu = parent[u]
            if pu >= 0:
                size[pu] += size[u]

        # Walk towards the heavy child until no child holds more than half.
        c = root
        moved = True
        while moved:
            moved = False
            pc = parent[c]
            for v in adj[c]:
                if v != pc and not removed[v] and size[v] * 2 > total:
                    c = v
                    moved = True
                    break
        removed[c] = True

        # Bucket charges by distance from c, per child subtree and overall.
        comp_sums = [charge[c]]
        branches = []
        for v in adj[c]:
            if removed[v]:
                continue
            parent[v] = c
            depth[v] = 1
            nodes = [v]
            head = 0
            while head < len(nodes):
                u = nodes[head]
                head += 1
                pu = parent[u]
                next_depth = depth[u] + 1
                for w in adj[u]:
                    if w != pu and not removed[w]:
                        parent[w] = u
                        depth[w] = next_depth
                        nodes.append(w)
            reach = depth[nodes[-1]]  # BFS order: the last vertex is deepest
            sums = [0] * (reach + 1)
            for u in nodes:
                sums[depth[u]] += charge[u]
            if reach >= len(comp_sums):
                comp_sums.extend([0] * (reach + 1 - len(comp_sums)))
            for d in range(1, reach + 1):
                s = sums[d] % _MOD
                sums[d] = s
                comp_sums[d] += s
            branches.append((nodes, sums))
            pending.append(v)  # the child subtree becomes its own component

        comp_sums = [s % _MOD for s in comp_sums]

        # The centroid itself sees every vertex of the component directly.
        acc = 0
        for d, s in enumerate(comp_sums):
            acc = (acc + s * g[d]) % _MOD
        ans[c] = (ans[c] + acc) % _MOD

        reach_all = len(comp_sums) - 1
        if reach_all == 0:
            continue

        # With the bucket list reversed, product index reach_all + d equals
        # sum_j comp_sums[j] * g[d + j]; only g[0 .. 2 * reach_all] is read.
        whole = _convolve(comp_sums[::-1], g[:2 * reach_all + 1])

        for nodes, sums in branches:
            reach = len(sums) - 1
            # Same trick restricted to this subtree, sized by its own depth.
            own = _convolve(sums[::-1], g[:2 * reach + 1])
            delta = [0] * (reach + 1)
            for d in range(1, reach + 1):
                delta[d] = whole[reach_all + d] - own[reach + d]
            for u in nodes:
                ans[u] = (ans[u] + delta[depth[u]]) % _MOD

    print("\n".join(str(k0 * a % _MOD) for a in ans))

# Run the solver only when this file is executed as a script, so that
# importing the module does not start reading from stdin.
if __name__ == "__main__":
    solve()
0