結果

問題 No.3350 Tree and Two Apples
コンテスト
ユーザー tassei903
提出日時 2025-11-12 09:11:43
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 913 ms / 2,000 ms
コード長 3,193 bytes
コンパイル時間 332 ms
コンパイル使用メモリ 82,600 KB
実行使用メモリ 137,752 KB
最終ジャッジ日時 2025-11-13 21:21:51
合計ジャッジ時間 28,954 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 35
権限があれば一括ダウンロードができます

ソースコード

diff #






import sys
from collections import deque, defaultdict
from itertools import groupby
from random import randint
# Modulus for the component-wise sums of random hash pairs.
M = 10 ** 18

# Memo table for f(): maps each key passed to f() to its random hash pair.
pool = {}

def f(e):
    """Return the random hash pair memoized for key *e*.

    On the first lookup of a key, two integers are drawn uniformly from
    ``[0, M)`` and stored in the module-level ``pool``. Every later lookup
    of the same key returns that stored pair.
    """
    cached = pool.get(e)
    if cached is None:
        # Stored values are always tuples, so None only means "absent".
        cached = (randint(0, M - 1), randint(0, M - 1))
        pool[e] = cached
    return cached

def solve(N, E):
    """Count vertex pairs of the tree ``E`` up to tree automorphism.

    The tree is considered rooted at every vertex ``r`` in turn. For each
    root we compute the number of vertex classes of the rooted tree, where
    isomorphic sibling subtrees are identified (equivalently, the orbits of
    the rooted tree's automorphism group). The result is the sum of those
    counts over pairwise non-isomorphic roots. That sum equals the number of
    orbits of ordered vertex pairs ``(r, v)`` under the automorphisms of the
    unrooted tree. NOTE(review): presumably this is what the problem asks
    for; confirm against the statement.

    Rooted-tree hashes are randomized: ``hash(T) = f(sum of child hashes)``,
    with the sums taken component-wise modulo ``M``. Collisions are possible
    but improbable. All roots are handled by a single rerooting pass:
    ``up`` values describe the subtree below a vertex, and ``down`` values
    describe the part of the tree on the parent's side.

    Args:
        N: number of vertices, labelled ``0 .. N-1``.
        E: adjacency lists; ``E[i]`` holds the neighbours of ``i``. Assumed
           to describe a tree (connected, ``N - 1`` edges). The traversal
           only skips the parent, so cycles are not supported.

    Returns:
        The orbit count described above (0 for an empty tree).
    """
    if N == 0:
        return 0

    def add(a, b):
        # Component-wise sum of two hash pairs modulo M.
        return ((a[0] + b[0]) % M, (a[1] + b[1]) % M)

    def sub(a, b):
        # Component-wise difference modulo M (the inverse of ``add``).
        return ((a[0] - b[0]) % M, (a[1] - b[1]) % M)

    zero = (0, 0)
    root = 0

    # BFS from the root; ``order`` lists every parent before its children.
    par = [-1] * N
    par[root] = -2  # sentinel that never equals a vertex id
    children = [[] for _ in range(N)]
    order = []
    queue = deque([root])
    while queue:
        i = queue.popleft()
        order.append(i)
        for j in E[i]:
            if j != par[i]:
                par[j] = i
                children[i].append(j)
                queue.append(j)

    # up[i]: hash of the subtree rooted at i; up2[i]: its vertex-class count.
    up = [zero] * N
    up2 = [0] * N
    for i in reversed(order):
        s = zero
        for j in children[i]:
            s = add(s, up[j])
        up[i] = f(s)
        # Isomorphic sibling subtrees contribute their classes only once.
        up2[i] = 1 + sum(cnt for _, cnt in {(up[j], up2[j]) for j in children[i]})

    # down[j]: hash of the tree obtained by deleting j's subtree, rooted at
    # j's parent; down2[j]: its vertex-class count.
    down = [zero] * N
    down2 = [0] * N
    classes_by_shape = {}
    for i in order:
        kids = children[i]
        # Multiplicity of each branch type around i, parent side included.
        mult = defaultdict(int)
        whole = down[i]  # stays zero at the root, which has no parent branch
        if i != root:
            mult[(down[i], down2[i])] += 1
        for j in kids:
            whole = add(whole, up[j])
            mult[(up[j], up2[j])] += 1
        count = 1 + sum(cnt for _, cnt in mult)
        # ``whole`` identifies the tree rooted at i: equal keys mean
        # isomorphic rootings, which necessarily share the same count.
        classes_by_shape[whole] = count
        for j in kids:
            down[j] = f(sub(whole, up[j]))
            branch = (up[j], up2[j])
            # Removing j's branch loses its classes only when no other
            # branch around i has the same shape.
            down2[j] = count - up2[j] if mult[branch] == 1 else count

    return sum(classes_by_shape.values())

# Input: N, then N-1 lines "u v" giving 1-indexed tree edges.
# All tokens are read at once, which is much faster than one input() call
# per line when N is large.
_tokens = sys.stdin.buffer.read().split()
N = int(_tokens[0])
E = [[] for _ in range(N)]
for _k in range(N - 1):
    u = int(_tokens[2 * _k + 1]) - 1
    v = int(_tokens[2 * _k + 2]) - 1
    E[u].append(v)
    E[v].append(u)
print(solve(N, E))

0