結果
| 問題 | No.439 チワワのなる木 |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-03-20 20:51:46 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 355 ms / 5,000 ms |
| コード長 | 2,538 bytes |
| コンパイル時間 | 169 ms |
| コンパイル使用メモリ | 82,172 KB |
| 実行使用メモリ | 112,996 KB |
| 最終ジャッジ日時 | 2025-03-20 20:52:19 |
| 合計ジャッジ時間 | 5,478 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 28 |
ソースコード
import sys
from collections import defaultdict, deque
def _count_cww_triples(n, s, edge_pairs):
    """Count (c, w, w) vertex triples whose middle vertex separates the ends.

    For every vertex ``j`` labelled ``'w'``, removing ``j`` splits the tree
    into components.  The function counts the pairs ``(x, y)`` with
    ``s[x-1] == 'c'`` and ``s[y-1] == 'w'`` (``y != j``) that lie in
    *different* components, and sums that count over all such ``j``.

    Args:
        n: Number of vertices; vertices are numbered ``1..n``.
        s: Label string of length ``n``; ``s[i-1]`` is the label of vertex
            ``i``.  Only the characters ``'c'`` and ``'w'`` are counted.
        edge_pairs: Iterable of ``(a, b)`` pairs, the undirected tree edges.

    Returns:
        The total number of such triples as an ``int``.
    """
    adjacency = [[] for _ in range(n + 1)]  # index 0 unused (1-based ids)
    for a, b in edge_pairs:
        adjacency[a].append(b)
        adjacency[b].append(a)

    # Root the tree at vertex 1 with a BFS.  The visit order is kept so that
    # subtree sums can be accumulated afterwards without recursion.
    root = 1
    parent = [0] * (n + 1)
    visited = [False] * (n + 1)
    visited[root] = True
    bfs_order = []
    queue = deque([root])
    while queue:
        u = queue.popleft()
        bfs_order.append(u)
        for v in adjacency[u]:
            if not visited[v]:
                visited[v] = True
                parent[v] = u
                queue.append(v)

    # c_sub[u] / w_sub[u]: number of 'c' / 'w' labels in the subtree of u.
    c_sub = [0] + [int(ch == "c") for ch in s]
    w_sub = [0] + [int(ch == "w") for ch in s]
    # In reverse BFS order every child is folded into its parent before the
    # parent itself is folded into the grandparent.
    for v in reversed(bfs_order[1:]):
        p = parent[v]
        c_sub[p] += c_sub[v]
        w_sub[p] += w_sub[v]

    total_c = s.count("c")
    total_w = s.count("w")

    # For a fixed 'w' vertex j:
    #   pairs in different components
    #     = all (c, w != j) pairs - pairs inside the same component.
    # The components are the part of the tree outside j's subtree and one
    # subtree per child of j.
    result = 0
    for j in range(1, n + 1):
        if s[j - 1] != "w":
            continue
        result += total_c * (total_w - 1)
        # Same-component pairs on the parent side of j.
        result -= (total_c - c_sub[j]) * (total_w - w_sub[j])
    for v in bfs_order[1:]:
        # Same-component pairs inside the child subtree v of a 'w' parent.
        if s[parent[v] - 1] == "w":
            result -= c_sub[v] * w_sub[v]
    return result


def main():
    """Read a labelled tree from stdin and print its (c, w, w) triple count.

    Expected input, as whitespace-separated tokens: ``n``, then the label
    string of length ``n``, then ``n - 1`` pairs ``a b`` of 1-based edge
    endpoints.  Tokens are read independently of line breaks, so stray
    blank lines or trailing whitespace do not matter.
    """
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    s = tokens[1]
    endpoints = [int(tok) for tok in tokens[2:2 + 2 * (n - 1)]]
    edge_pairs = zip(endpoints[0::2], endpoints[1::2])
    print(_count_cww_triples(n, s, edge_pairs))


if __name__ == "__main__":
    main()
lam6er