結果

問題 No.439 チワワのなる木
ユーザー lam6er
提出日時 2025-03-20 20:51:46
言語 PyPy3 (7.3.15)
結果 AC
実行時間 355 ms / 5,000 ms
コード長 2,538 bytes
コンパイル時間 169 ms
コンパイル使用メモリ 82,172 KB
実行使用メモリ 112,996 KB
最終ジャッジ日時 2025-03-20 20:52:19
合計ジャッジ時間 5,478 ms
ジャッジサーバーID (参考情報) judge5 / judge2
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
other AC * 28
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict, deque

def _count_cww_paths(n, s, adjacency):
    """Count (c, w, w) vertex triples whose middle 'w' lies between the others.

    A triple of vertices (x, j, y) is counted when ``s[x-1] == 'c'``,
    ``s[j-1] == 'w'``, ``s[y-1] == 'w'`` and j lies strictly inside the tree
    path from x to y.  For a fixed 'w' vertex j this equals the number of
    ('c', 'w') vertex pairs that lie in two *different* branches of j, where
    the branches are the components left after deleting j (each child
    subtree, plus everything outside j's own subtree).

    Args:
        n: Number of vertices; vertices are numbered 1..n.
        s: Vertex labels; ``s[v - 1]`` is the label of vertex ``v``.
        adjacency: 1-based adjacency lists of a tree on vertices 1..n
            (``adjacency[0]`` is unused).

    Returns:
        The number of such triples.
    """
    root = 1
    parent = [0] * (n + 1)  # parent[root] stays 0, which is not a vertex
    seen = [False] * (n + 1)
    seen[root] = True
    order = []  # discovery order: every vertex appears after its parent
    stack = [root]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adjacency[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                stack.append(v)

    # Per-subtree label counts.  Walking the discovery order backwards
    # finishes every child before its parent, so one pass is enough.
    c_sub = [0] * (n + 1)
    w_sub = [0] * (n + 1)
    for u in reversed(order):
        label = s[u - 1]
        if label == 'c':
            c_sub[u] += 1
        elif label == 'w':
            w_sub[u] += 1
        p = parent[u]
        if p:  # the root has no parent to propagate into
            c_sub[p] += c_sub[u]
            w_sub[p] += w_sub[u]

    total_c = s.count('c')
    total_w = s.count('w')

    result = 0
    for j in range(1, n + 1):
        if s[j - 1] != 'w':
            continue
        # A ('c', 'w') pair inside ONE branch does not pass through j, so
        # subtract those from all pairs drawn from j's branches.
        same_branch = (total_c - c_sub[j]) * (total_w - w_sub[j])  # parent side
        for v in adjacency[j]:
            if v != parent[j]:  # child subtrees
                same_branch += c_sub[v] * w_sub[v]
        # Every 'c' lies in some branch; j itself is the only 'w' that does not.
        result += total_c * (total_w - 1) - same_branch
    return result


def main():
    """Read one test case from standard input and print the triple count.

    Expected input: ``n``, the label string ``s`` (one character per
    vertex), then ``n - 1`` edges ``a b``.  Tokens are split on any
    whitespace, so the exact line layout does not matter.
    """
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    s = tokens[1]
    adjacency = [[] for _ in range(n + 1)]  # 1-based indexing
    endpoints = map(int, tokens[2:2 + 2 * (n - 1)])
    # zip over the same iterator twice pairs consecutive tokens as (a, b).
    for a, b in zip(endpoints, endpoints):
        adjacency[a].append(b)
        adjacency[b].append(a)
    print(_count_cww_paths(n, s, adjacency))

if __name__ == "__main__":
    # Solve only when run as a script; importing this module has no effect.
    main()
0