結果

問題 No.1333 Squared Sum
ユーザー lam6er
提出日時 2025-03-20 20:56:56
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 725 ms / 2,000 ms
コード長 2,049 bytes
コンパイル時間 282 ms
コンパイル使用メモリ 82,684 KB
実行使用メモリ 169,364 KB
最終ジャッジ日時 2025-03-20 20:57:24
合計ジャッジ時間 18,880 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 44
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

MOD = 10**9 + 7

def main():
    sys.setrecursionlimit(1 << 25)
    N = int(sys.stdin.readline())
    adj = [[] for _ in range(N + 1)]
    for _ in range(N - 1):
        u, v, w = map(int, sys.stdin.readline().split())
        adj[u].append((v, w))
        adj[v].append((u, w))
    
    # Build children list with BFS
    root = 1
    children = [[] for _ in range(N + 1)]
    visited = [False] * (N + 1)
    q = deque()
    q.append(root)
    visited[root] = True
    while q:
        u = q.popleft()
        for v, w in adj[u]:
            if not visited[v]:
                visited[v] = True
                children[u].append((v, w))
                q.append(v)
    
    # Post-order traversal
    post_order = []
    stack = [(root, False)]
    while stack:
        node, visited_flag = stack.pop()
        if visited_flag:
            post_order.append(node)
        else:
            stack.append((node, True))
            # Push children in reverse order to process them in order
            for child, w in reversed(children[node]):
                stack.append((child, False))
    
    s0 = [0] * (N + 1)
    s1 = [0] * (N + 1)
    s2 = [0] * (N + 1)
    ans = 0
    
    for u in post_order:
        s0[u] = 1
        s1[u] = 0
        s2[u] = 0
        for v, w in children[u]:
            # Process child v
            s0_child = s0[v]
            s1_child = (s1[v] + w * s0_child) % MOD
            s2_child = (s2[v] + 2 * w * s1[v] + (w * w) % MOD * s0_child) % MOD
            
            # Calculate contribution
            contrib = (s2[u] * s0_child) % MOD
            contrib = (contrib + (s2_child * s0[u]) % MOD) % MOD
            contrib = (contrib + (2 * ((s1[u] * s1_child) % MOD)) % MOD) % MOD
            ans = (ans + contrib) % MOD
            
            # Merge into parent
            s0[u] = (s0[u] + s0_child) % MOD
            s1[u] = (s1[u] + s1_child) % MOD
            s2[u] = (s2[u] + s2_child) % MOD
    
    print(ans % MOD)

if __name__ == '__main__':
    main()
0