結果
問題 | No.439 チワワのなる木 |
ユーザー | terasa |
提出日時 | 2022-11-03 21:41:29 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 758 ms / 5,000 ms |
コード長 | 3,620 bytes |
コンパイル時間 | 280 ms |
コンパイル使用メモリ | 87,196 KB |
実行使用メモリ | 150,508 KB |
最終ジャッジ日時 | 2023-09-25 05:10:40 |
合計ジャッジ時間 | 11,125 ms |
ジャッジサーバーID (参考情報) | judge14 / judge13 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 165 ms
80,416 KB |
testcase_01 | AC | 167 ms
80,292 KB |
testcase_02 | AC | 161 ms
80,332 KB |
testcase_03 | AC | 161 ms
80,332 KB |
testcase_04 | AC | 161 ms
80,180 KB |
testcase_05 | AC | 164 ms
80,272 KB |
testcase_06 | AC | 162 ms
80,264 KB |
testcase_07 | AC | 162 ms
80,288 KB |
testcase_08 | AC | 164 ms
80,188 KB |
testcase_09 | AC | 164 ms
80,156 KB |
testcase_10 | AC | 167 ms
80,308 KB |
testcase_11 | AC | 165 ms
80,436 KB |
testcase_12 | AC | 164 ms
80,320 KB |
testcase_13 | AC | 172 ms
80,964 KB |
testcase_14 | AC | 177 ms
80,928 KB |
testcase_15 | AC | 271 ms
84,132 KB |
testcase_16 | AC | 262 ms
85,032 KB |
testcase_17 | AC | 243 ms
84,844 KB |
testcase_18 | AC | 583 ms
115,964 KB |
testcase_19 | AC | 546 ms
114,492 KB |
testcase_20 | AC | 689 ms
126,576 KB |
testcase_21 | AC | 340 ms
96,364 KB |
testcase_22 | AC | 345 ms
95,060 KB |
testcase_23 | AC | 758 ms
129,328 KB |
testcase_24 | AC | 663 ms
131,364 KB |
testcase_25 | AC | 580 ms
140,684 KB |
testcase_26 | AC | 580 ms
150,508 KB |
testcase_27 | AC | 421 ms
132,528 KB |
ソースコード
from typing import Callable, List, Optional, Tuple, TypeVar
import sys
import itertools
import heapq
import bisect
from collections import deque, defaultdict
from functools import lru_cache, cmp_to_key

# Fast line reader; the name ``input`` is rebound on purpose so the helper
# functions below read through ``sys.stdin.readline``.
input = sys.stdin.readline

# for AtCoder Easy test
if __file__ != 'prog.py':
    sys.setrecursionlimit(10 ** 6)


def readints():
    """Read one line and return a lazy iterator over its integers."""
    return map(int, input().split())


def readlist():
    """Read one line and return its integers as a list."""
    return list(readints())


def readstr():
    """Read one line and return it with trailing whitespace removed."""
    return input().rstrip()


T = TypeVar('T')


class Rerooting:
    """Rerooting DP over a tree, implemented without recursion.

    reference: https://null-mn.hatenablog.com/entry/2020/04/14/124151

    The value ``dp_v`` computed for the subtree rooted at some vertex ``v``
    must be expressible through the children ``c1, c2, ..., ck`` of ``v`` as

        dp_v = g(merge(f(dp_c1, c1), f(dp_c2, c2), ..., f(dp_ck, ck)), v)

    Parameters
    ----------
    N:
        Number of vertices (labelled ``0 .. N-1``).
    E:
        Adjacency list; ``E[v]`` is a list of ``(cost, neighbour)`` tuples.
    f:
        Called as ``f(dp_child, v, child, cost)``; lifts a child's value
        across the edge ``v -- child``.
    g:
        Called as ``g(accumulated, v)``; finalises the merged value at ``v``.
    merge:
        Binary operation combining two lifted values.  Prefix and suffix
        accumulations are combined, so it has to be associative with ``e``
        as its identity.
    e:
        Identity element of ``merge``.

    After construction ``solve(v)`` returns the value for the whole tree
    rooted at ``v``.
    """

    def __init__(self,
                 N: int,
                 E: List[List[Tuple[int, int]]],
                 f: Callable[[T, int, int, int], T],
                 g: Callable[[T, int], T],
                 merge: Callable[[T, T], T],
                 e: T):
        self.N = N
        self.E = E
        self.f = f
        self.g = g
        self.merge = merge
        self.e = e

        # dp[v][i]: value of the subtree hanging off the i-th edge of v,
        # i.e. rooted at the neighbour E[v][i][1] when v is its parent.
        self.dp = [[self.e] * len(self.E[v]) for v in range(self.N)]
        # An empty tree has no root to start from; leave dp empty.
        if self.N > 0:
            self._calculate()

    def _dfs1(self, root):
        """Bottom-up pass: fill dp for every edge pointing away from root."""
        stack = [(root, -1)]
        ret = [self.e] * self.N
        while stack:
            v, p = stack.pop()
            if v < 0:
                # Post-order visit (vertex was pushed as ~v): every child
                # has already been finalised in ``ret``.
                v = ~v
                acc = self.e
                for i, (c, d) in enumerate(self.E[v]):
                    if d == p:
                        continue
                    self.dp[v][i] = ret[d]
                    acc = self.merge(acc, self.f(ret[d], v, d, c))
                ret[v] = self.g(acc, v)
                continue
            # Pre-order visit: schedule the post-order marker, then children.
            stack.append((~v, p))
            for _, d in self.E[v]:
                if d == p:
                    continue
                stack.append((d, v))

    def _dfs2(self, root):
        """Top-down pass: fill dp for every edge pointing towards root."""
        stack = [(root, -1, self.e)]
        while stack:
            v, p, from_par = stack.pop()
            # Store the value coming from the parent's side on the parent edge.
            for i, (_, d) in enumerate(self.E[v]):
                if d == p:
                    self.dp[v][i] = from_par
                    break
            ch = len(self.E[v])
            # Sr[i]: merge of the lifted values of edges i .. ch-1.
            Sr = [self.e] * (ch + 1)
            for i in range(ch, 0, -1):
                c, d = self.E[v][i - 1]
                Sr[i - 1] = self.merge(
                    Sr[i], self.f(self.dp[v][i - 1], v, d, c))
            # Sl: merge of the lifted values of edges 0 .. i-1.
            Sl = self.e
            for i, (c, d) in enumerate(self.E[v]):
                if d != p:
                    # Everything around v except the edge to d itself.
                    val = self.merge(Sl, Sr[i + 1])
                    stack.append((d, v, self.g(val, v)))
                Sl = self.merge(Sl, self.f(self.dp[v][i], v, d, c))

    def _calculate(self, root=0):
        """Run both passes starting from ``root``."""
        self._dfs1(root)
        self._dfs2(root)

    def solve(self, v):
        """Return the DP value of the whole tree when rooted at ``v``."""
        ans = self.e
        for i, (c, d) in enumerate(self.E[v]):
            ans = self.merge(ans, self.f(self.dp[v][i], v, d, c))
        return self.g(ans, v)


def f(a, v, ch, cost):
    """Lift a subtree value ``(c_count, w_count, _)`` across an edge.

    The third component becomes ``c_count * w_count``: the number of
    ('c' vertex, 'w' vertex) pairs lying inside that single subtree.
    """
    return (a[0], a[1], a[0] * a[1])


def merge(a, b):
    """Component-wise sum of two 3-tuples."""
    return (a[0] + b[0], a[1] + b[1], a[2] + b[2])


def count_cww_triples(N: int, S: str, E: List[List[Tuple[int, int]]]) -> int:
    """Count (c, w, w) vertex pairs separated by each 'w' vertex.

    For every vertex ``i`` with ``S[i] == 'w'`` this adds the number of
    pairs (a vertex labelled 'c', another vertex not labelled 'c') that lie
    in *different* neighbour subtrees of ``i``.  In a tree those are exactly
    the pairs whose connecting path runs through ``i``.

    Parameters
    ----------
    N: number of vertices.
    S: label string; ``S[v] == 'c'`` is counted as 'c', anything else as 'w'.
    E: adjacency list of ``(cost, neighbour)`` tuples describing a tree.
    """
    if N <= 0:
        return 0

    def g(a, v):
        # Add vertex v itself to the matching counter.
        if S[v] == 'c':
            return (a[0] + 1, a[1], a[2])
        else:
            return (a[0], a[1] + 1, a[2])

    solver = Rerooting(N, E, f, g, merge, (0, 0, 0))
    ans = 0
    for i in range(N):
        if S[i] != 'w':
            continue
        c, w, d = solver.solve(i)
        # c * (w - 1): all ('c', other 'w') pairs around i;
        # d: pairs that sit together inside one neighbour subtree.
        ans += c * (w - 1) - d
    return ans


def main() -> None:
    """Read the tree from stdin and print the answer."""
    N = int(input())
    S = readstr()
    E = [[] for _ in range(N)]
    for _ in range(N - 1):
        a, b = readints()
        a -= 1
        b -= 1
        # Edges are stored as (cost, neighbour); every edge has cost 1.
        E[a].append((1, b))
        E[b].append((1, a))
    print(count_cww_triples(N, S, E))


if __name__ == '__main__':
    main()