結果
問題 | No.439 チワワのなる木 |
ユーザー | terasa |
提出日時 | 2022-11-03 21:41:29 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 578 ms / 5,000 ms |
コード長 | 3,620 bytes |
コンパイル時間 | 175 ms |
コンパイル使用メモリ | 82,100 KB |
実行使用メモリ | 145,744 KB |
最終ジャッジ日時 | 2024-07-18 04:24:14 |
合計ジャッジ時間 | 7,357 ms |
ジャッジサーバーID (参考情報) | judge1 / judge3 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 59 ms
69,364 KB |
testcase_01 | AC | 55 ms
69,992 KB |
testcase_02 | AC | 60 ms
70,208 KB |
testcase_03 | AC | 56 ms
69,552 KB |
testcase_04 | AC | 56 ms
68,920 KB |
testcase_05 | AC | 56 ms
69,052 KB |
testcase_06 | AC | 55 ms
69,608 KB |
testcase_07 | AC | 55 ms
69,388 KB |
testcase_08 | AC | 57 ms
69,756 KB |
testcase_09 | AC | 56 ms
70,496 KB |
testcase_10 | AC | 58 ms
70,452 KB |
testcase_11 | AC | 60 ms
68,880 KB |
testcase_12 | AC | 61 ms
70,080 KB |
testcase_13 | AC | 67 ms
72,432 KB |
testcase_14 | AC | 79 ms
77,656 KB |
testcase_15 | AC | 120 ms
80,136 KB |
testcase_16 | AC | 137 ms
80,780 KB |
testcase_17 | AC | 130 ms
80,228 KB |
testcase_18 | AC | 401 ms
112,364 KB |
testcase_19 | AC | 384 ms
110,236 KB |
testcase_20 | AC | 555 ms
122,556 KB |
testcase_21 | AC | 233 ms
93,240 KB |
testcase_22 | AC | 225 ms
91,220 KB |
testcase_23 | AC | 578 ms
125,972 KB |
testcase_24 | AC | 565 ms
129,312 KB |
testcase_25 | AC | 525 ms
136,688 KB |
testcase_26 | AC | 441 ms
145,744 KB |
testcase_27 | AC | 281 ms
128,456 KB |
ソースコード
import bisect
import heapq
import itertools
import sys
from collections import defaultdict, deque
from functools import cmp_to_key, lru_cache
from typing import Callable, Iterable, List, Optional, Tuple, TypeVar

# Fast line reader. It deliberately replaces the builtin because the
# read* helpers below are written in terms of `input`.
input = sys.stdin.readline

# for AtCoder Easy test: the recursion limit is raised unless this file is
# being run under the name 'prog.py'.
if __file__ != 'prog.py':
    sys.setrecursionlimit(10 ** 6)


def readints():
    """Return an iterator over the ints on the next input line."""
    return map(int, input().split())


def readlist():
    """Return the ints on the next input line as a list."""
    return list(readints())


def readstr():
    """Return the next input line with trailing whitespace removed."""
    return input().rstrip()


T = TypeVar('T')


class Rerooting:
    """Rerooting DP over a tree: evaluate a subtree DP with every vertex as root.

    reference: https://null-mn.hatenablog.com/entry/2020/04/14/124151

    The value dp_v computed for the subtree rooted at a vertex v is expressed
    through v's children c1, c2, ..., ck as

        dp_v = g(merge(f(dp_c1, v, c1, cost1), ..., f(dp_ck, v, ck, costk)), v)

    Args:
        N: number of vertices, numbered 0 .. N-1.
        E: adjacency list; ``E[v]`` is a list of ``(cost, neighbour)`` pairs.
        f: lifts a child's value across an edge; called as
            ``f(child_value, parent, child, edge_cost)``.
        g: finalises the merged value for a vertex; called as ``g(acc, v)``.
        merge: combines two lifted values.  The prefix/suffix folds in
            ``_dfs2`` assume it is associative with ``e`` as identity.
        e: initial accumulator for every fold.

    After construction ``dp[v][i]`` holds the value of the component that
    contains the i-th neighbour of ``v`` once the edge to ``v`` is removed,
    with that neighbour as its root.  The whole computation is iterative
    (explicit stacks), and vertex 0 is used as the initial root.
    """

    def __init__(self, N: int, E: List[List[Tuple[int, int]]],
                 f: Callable[[T, int, int, int], T],
                 g: Callable[[T, int], T],
                 merge: Callable[[T, T], T], e: T):
        self.N = N
        self.E = E
        self.f = f
        self.g = g
        self.merge = merge
        self.e = e
        self.dp = [[self.e for _ in range(len(self.E[v]))]
                   for v in range(self.N)]
        self._calculate()

    def _dfs1(self, root):
        """Bottom-up pass: fill ``dp[v][i]`` for every child direction."""
        stack = [(root, -1)]
        ret = [self.e] * self.N
        while stack:
            v, p = stack.pop()
            if v < 0:
                # Post-order visit, encoded as ~v: all children are finished.
                v = ~v
                acc = self.e
                for i, (c, d) in enumerate(self.E[v]):
                    if d == p:
                        continue
                    self.dp[v][i] = ret[d]
                    acc = self.merge(acc, self.f(ret[d], v, d, c))
                ret[v] = self.g(acc, v)
                continue
            # Pre-order visit: schedule the post-order marker, then children.
            stack.append((~v, p))
            for _, d in self.E[v]:
                if d == p:
                    continue
                stack.append((d, v))

    def _dfs2(self, root):
        """Top-down pass: fill the parent direction of every ``dp[v]``."""
        stack = [(root, -1, self.e)]
        while stack:
            v, p, from_par = stack.pop()
            for i, (c, d) in enumerate(self.E[v]):
                if d == p:
                    self.dp[v][i] = from_par
                    break
            ch = len(self.E[v])
            # Sr[i] = merge of lifted values of neighbours i .. ch-1.
            Sr = [self.e] * (ch + 1)
            for i in range(ch, 0, -1):
                c, d = self.E[v][i - 1]
                Sr[i - 1] = self.merge(
                    Sr[i], self.f(self.dp[v][i - 1], v, d, c))
            # Sl = merge of lifted values of neighbours 0 .. i-1.
            Sl = self.e
            for i, (c, d) in enumerate(self.E[v]):
                if d != p:
                    # Everything around v except the direction of child d.
                    val = self.merge(Sl, Sr[i + 1])
                    stack.append((d, v, self.g(val, v)))
                Sl = self.merge(Sl, self.f(self.dp[v][i], v, d, c))

    def _calculate(self, root=0):
        """Run both passes from ``root``; a tree with no vertices is a no-op."""
        if self.N == 0:
            return
        self._dfs1(root)
        self._dfs2(root)

    def solve(self, v):
        """Return the DP value of the whole tree when rooted at ``v``."""
        ans = self.e
        for i, (c, d) in enumerate(self.E[v]):
            ans = self.merge(ans, self.f(self.dp[v][i], v, d, c))
        return self.g(ans, v)


def count_cww_paths(n: int, s: str, edges: Iterable[Tuple[int, int]]) -> int:
    """Count (c-vertex, w-centre, other w-vertex) triples split by the centre.

    Every vertex whose letter is not 'c' is counted as a 'w' vertex.  For each
    vertex ``v`` with ``s[v] == 'w'`` the function adds the number of pairs
    (x, y) with x a 'c' vertex, y a 'w' vertex other than ``v``, and x, y in
    different components of the tree after removing ``v``.

    Args:
        n: number of vertices.
        s: string of length ``n``; ``s[v]`` is the letter on vertex ``v``.
        edges: 0-based ``(a, b)`` endpoint pairs of the tree's edges.

    Returns:
        The total over all 'w' centres.
    """
    adj: List[List[Tuple[int, int]]] = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append((1, b))
        adj[b].append((1, a))

    # DP value: (#c vertices, #w vertices, sum of #c * #w over child branches).
    def f(a, v, ch, cost):
        return (a[0], a[1], a[0] * a[1])

    def g(a, v):
        if s[v] == 'c':
            return (a[0] + 1, a[1], a[2])
        return (a[0], a[1] + 1, a[2])

    def merge(a, b):
        return (a[0] + b[0], a[1] + b[1], a[2] + b[2])

    solver = Rerooting(n, adj, f, g, merge, (0, 0, 0))
    total = 0
    for v in range(n):
        if s[v] != 'w':
            continue
        c, w, d = solver.solve(v)
        # c * (w - 1): all (c, other w) pairs; d: pairs inside a single branch.
        total += c * (w - 1) - d
    return total


def main() -> None:
    """Read the tree from stdin (1-based edges) and print the answer."""
    n = int(input())
    s = readstr()
    edges = []
    for _ in range(n - 1):
        a, b = readints()
        edges.append((a - 1, b - 1))
    print(count_cww_paths(n, s, edges))


if __name__ == '__main__':
    main()