結果
| 問題 | No.3350 Tree and Two Apples |
|---|---|
| コンテスト | |
| ユーザー | |
| 提出日時 | 2025-11-12 09:11:43 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 913 ms / 2,000 ms |
| コード長 | 3,193 bytes |
| コンパイル時間 | 332 ms |
| コンパイル使用メモリ | 82,600 KB |
| 実行使用メモリ | 137,752 KB |
| 最終ジャッジ日時 | 2025-11-13 21:21:51 |
| 合計ジャッジ時間 | 28,954 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge4 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 35 |
ソースコード
import sys
from collections import deque, defaultdict
from itertools import groupby
from random import randint
# Modulus for every hash component; values produced by f lie in [0, M).
M = 10**18
# Memo table for f: key -> random pair, shared by every call.
pool = dict()
# Not referenced by the visible code; kept so the module name stays defined.
lcg_gen = None


def f(e):
    """Map key ``e`` to a random pair that stays fixed after first use.

    A key seen for the first time is assigned
    ``(randint(0, M - 1), randint(0, M - 1))`` and stored in ``pool``;
    every later call with an equal key returns that same tuple object.
    """
    value = pool.get(e)
    if value is None:
        # Stored values are always tuples, so None unambiguously means "miss".
        value = (randint(0, M - 1), randint(0, M - 1))
        pool[e] = value
    return value
def solve(N, E):
    """Return the answer for the tree with ``N`` vertices and adjacency ``E``.

    ``E[v]`` lists the 0-based neighbours of ``v``. The input is assumed to
    be a tree (connected, N - 1 edges) with N >= 1. Vertex 0 is the BFS root.

    Rooted subtrees are hashed by summing the children's hash pairs modulo
    ``M`` and passing the sum through the memoised random map ``f``. A
    rerooting pass extends this to every vertex as a root. For each vertex
    i, ``dp2[i]`` is 1 plus the sum of counts over the distinct neighbour
    component types. The result adds ``dp2`` once per distinct whole-tree
    hash, i.e. once per root up to hash equality.

    NOTE(review): this presumably counts apple placements up to tree
    isomorphism (No.3350 "Tree and Two Apples"). Confirm against the
    problem statement.
    """
    par, order = _bfs_order(N, E)
    up, up2 = _subtree_pass(N, E, par, order)
    down, dp2 = _reroot_pass(N, E, par, order, up, up2)
    return _sum_over_distinct_roots(N, E, par, up, down, dp2)


def _add_hash(a, b):
    """Return the component-wise sum of hash pairs ``a`` and ``b`` modulo M."""
    return ((a[0] + b[0]) % M, (a[1] + b[1]) % M)


def _bfs_order(N, E):
    """BFS from vertex 0 and return ``(par, order)``.

    ``par[0]`` is the root sentinel -2, and every other reached vertex gets
    its BFS parent. ``order`` lists vertices so that parents precede their
    children.
    """
    par = [-1] * N
    par[0] = -2
    order = []
    q = deque([0])
    while q:
        i = q.popleft()
        order.append(i)
        for j in E[i]:
            if j != par[i]:
                par[j] = i
                q.append(j)
    return par, order


def _subtree_pass(N, E, par, order):
    """Bottom-up pass over the tree rooted at 0; return ``(up, up2)``.

    ``up[i]`` is ``f`` of the sum of the children's ``up`` hashes, so a leaf
    gets ``f((0, 0))``. ``up2[i]`` is 1 plus the sum of ``up2`` over the
    pairwise-distinct ``(up[j], up2[j])`` pairs among i's children.
    """
    up = [(0, 0)] * N
    up2 = [0] * N
    for i in reversed(order):  # children are finished before their parent
        acc = (0, 0)
        kinds = set()
        for j in E[i]:
            if j == par[i]:
                continue
            acc = _add_hash(acc, up[j])
            kinds.add((up[j], up2[j]))
        up[i] = f(acc)
        up2[i] = sum(cnt for _, cnt in kinds) + 1
    return up, up2


def _reroot_pass(N, E, par, order, up, up2):
    """Top-down rerooting pass; return ``(down, dp2)``.

    For a non-root vertex j with parent i, ``down[j]`` is ``f`` of the sum of
    the hashes of every neighbour component of i except j's subtree.
    ``down2[j]`` is the matching count under the ``up2`` recurrence.
    ``dp2[i]`` applies that recurrence at i over all neighbour components.
    The root keeps ``down[0] == (0, 0)``, the additive identity, so it adds
    nothing to its children's sums.
    """
    down = [(0, 0)] * N
    down2 = [0] * N
    dp2 = [0] * N
    for i in order:  # down[i]/down2[i] were filled in while visiting par[i]
        children = [j for j in E[i] if j != par[i]]
        # Multiplicity of each neighbour component type around i.
        mp = defaultdict(int)
        if par[i] != -2:
            mp[(down[i], down2[i])] += 1
        for j in children:
            mp[(up[j], up2[j])] += 1
        total = sum(cnt for _, cnt in mp) + 1
        dp2[i] = total
        # Prefix sums (seeded with the parent side) ...
        pre = down[i]
        for j in children:
            down[j] = pre
            pre = _add_hash(pre, up[j])
        # ... plus suffix sums give "all neighbours of i except j".
        suf = (0, 0)
        for j in reversed(children):
            down[j] = f(_add_hash(down[j], suf))
            # Removing j's subtree drops its type only if that type was unique.
            if mp[(up[j], up2[j])] == 1:
                down2[j] = total - up2[j]
            else:
                down2[j] = total
            suf = _add_hash(suf, up[j])
    return down, dp2


def _sum_over_distinct_roots(N, E, par, up, down, dp2):
    """Add ``dp2[i]`` once per distinct whole-tree key, over all roots i.

    The key for root i is the sum of all its neighbour component hashes:
    ``down[i]`` for the parent side plus ``up[j]`` for each child. Roots with
    equal keys are counted once.
    """
    by_key = {}
    for i in range(N):
        # down[i] components already lie in [0, M), so no reduction needed.
        key = down[i] if par[i] != -2 else (0, 0)
        for j in E[i]:
            if j != par[i]:
                key = _add_hash(key, up[j])
        by_key[key] = dp2[i]
    return sum(by_key.values())
# Script entry: read N and then N - 1 edges "u v" (1-based) from stdin and
# print the answer. A single bulk read replaces per-line input(), which is
# slow for large N; `sys` was already imported for this.
_tokens = iter(sys.stdin.buffer.read().split())
N = int(next(_tokens))
E = [[] for _ in range(N)]
for _ in range(N - 1):
    u = int(next(_tokens))
    v = int(next(_tokens))
    E[u - 1].append(v - 1)
    E[v - 1].append(u - 1)
print(solve(N, E))