結果

問題 No.1002 Twotone
ユーザー mkawa2
提出日時 2025-05-16 00:27:24
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 3,634 ms / 5,000 ms
コード長 3,251 bytes
コンパイル時間 421 ms
コンパイル使用メモリ 82,232 KB
実行使用メモリ 251,720 KB
最終ジャッジ日時 2025-05-16 00:28:23
合計ジャッジ時間 57,137 ms
ジャッジサーバーID (参考情報) judge2 / judge3
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 33
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Generous recursion limit for recursive helpers (nothing in the visible code recurses).
sys.setrecursionlimit(200005)
# Optional on Python 3.11+ when very large integers are printed:
# sys.set_int_max_str_digits(200005)


def int1(x):
    """Convert ``x`` to ``int`` and make it 0-based by subtracting one."""
    return int(x) - 1


def pDB(*x):
    """Debug helper: print the arguments on one line to stderr."""
    print(*x, file=sys.stderr)


def p2D(x):
    """Debug helper: print each item of ``x`` on its own line to stderr, then a blank line."""
    print(*x, sep="\n", end="\n\n", file=sys.stderr)


def II():
    """Read one line from stdin as a single integer."""
    return int(sys.stdin.readline())


def LI():
    """Read one line from stdin as a list of integers."""
    return [int(tok) for tok in sys.stdin.readline().split()]


def LLI(rows_number):
    """Read ``rows_number`` lines from stdin, each as a list of integers."""
    return [LI() for _ in range(rows_number)]


def LI1():
    """Read one line from stdin as a list of integers, each decremented by one."""
    return [int(tok) - 1 for tok in sys.stdin.readline().split()]


def LLI1(rows_number):
    """Read ``rows_number`` lines from stdin, each as a list of 0-based integers."""
    return [LI1() for _ in range(rows_number)]


def SI():
    """Read one line from stdin with trailing whitespace removed."""
    return sys.stdin.readline().rstrip()

# Template constants; none of them is referenced by the visible code.
# Grid step offsets: the four orthogonal moves come first (slice dij[:4] for a
# 4-neighbourhood), followed by the four diagonal moves.
dij = [
    (0, 1), (-1, 0), (0, -1), (1, 0),
    (1, 1), (1, -1), (-1, 1), (-1, -1),
]
# "Infinity" sentinel equal to 2**62 - 1 (a 32-bit variant would be 2**31 - 1).
inf = (1 << 62) - 1
# Modulus; 10**9 + 7 is the usual alternative.
md = 998244353

# Centroid decomposition, computed iteratively (explicit stacks, no Python recursion).
def centroid_finder(to, root=0):
    """Return the vertices of a tree in centroid-decomposition order.

    The first vertex is a centroid of the whole tree.  Removing it splits the
    tree into components, and each component is decomposed the same way.
    Every vertex therefore appears before all other vertices of the component
    it was chosen from.

    Args:
        to: adjacency list; ``to[u]`` holds ``(v, payload)`` pairs and only
            ``v`` is used.  Assumed to describe a tree (connected, acyclic).
        root: vertex at which the first size computation starts.

    Returns:
        List of all vertices, each exactly once; ``[]`` when ``to`` is empty.
    """
    n = len(to)
    if n == 0:
        return []  # empty graph: nothing to decompose
    centroids = []
    size = [1] * n        # subtree sizes from the latest traversal that covered u
    is_removed = [0] * n  # 1 once u has been emitted as a centroid
    parent = [-1] * n     # parent pointers from the latest traversal
    # Pending components: (start vertex, whether sizes/parents must be recomputed).
    pending = [(root, True)]
    while pending:
        r, update = pending.pop()
        parent[r] = -1
        if update:
            # DFS from r over non-removed vertices, then accumulate subtree
            # sizes bottom-up (reverse pre-order visits children before parents).
            stack = [r]
            dfs_order = []
            while stack:
                u = stack.pop()
                size[u] = 1
                dfs_order.append(u)
                for v, _ in to[u]:
                    if v == parent[u] or is_removed[v]:
                        continue
                    parent[v] = u
                    stack.append(v)
            for u in reversed(dfs_order):
                if u == r:
                    break
                size[parent[u]] += size[u]
        # Descend from r towards a child whose subtree exceeds half of the
        # component; stop at the first vertex with no such child.
        c = r
        while True:
            mx, heavy = size[r] // 2, -1
            for v, _ in to[c]:
                if v == parent[c] or is_removed[v]:
                    continue
                if size[v] > mx:
                    mx, heavy = size[v], v
            if heavy == -1:
                break
            c = heavy
        centroids.append(c)
        is_removed[c] = 1
        for v, _ in to[c]:
            if is_removed[v]:
                continue
            # A child of c keeps its subtree from the last traversal, so its
            # sizes stay valid.  The parent side is the rest of the component
            # and has to be re-traversed.
            pending.append((v, v == parent[c]))
    return centroids

from collections import Counter

# Input: n vertices and a second header value k (presumably the number of
# colours; not used by the code below), then n - 1 edge lines "u v c".
# LLI1 shifts every value to 0-based.
n, k = LI()
to = [[] for _ in range(n)]  # to[u] = list of (neighbour, edge colour)
for u, v, c in LLI1(n - 1):
    to[u].append((v, c))
    to[v].append((u, c))

def dfs(v, c, r):
    """Count colour patterns of paths that leave centroid ``r`` through edge (r, v).

    Every path starts at ``r``, takes the edge ``r - v`` (colour ``c``) and
    then walks away from ``r``.  A path's colour set is encoded as a key:
      * ``(-1, y)`` -- the path uses the single colour ``y``;
      * ``(x, y)``  -- the path uses exactly the two colours ``x < y``.
    A path that would pick up a third colour is dropped, together with all of
    its extensions.

    Reads the module-level adjacency list ``to`` and marker list ``fin``.
    Vertices with ``fin`` set are counted as path endpoints but never walked
    through.  If ``v`` itself is marked, only the one-edge path ``r - v`` is
    counted.

    Returns:
        ``(cr, sr)``.  ``cr`` maps each key to its number of paths.  ``sr[-1]``
        counts the single-colour paths, and ``sr[z]`` counts the two-colour
        paths whose key contains colour ``z`` (colours assumed non-negative).
    """
    by_pattern = Counter({(-1, c): 1})
    by_colour = Counter({-1: 1})
    if fin[v]:
        return by_pattern, by_colour
    pending = [(v, r, -1, c)]
    while pending:
        node, came_from, lo, hi = pending.pop()
        for nxt, col in to[node]:
            if nxt == came_from:
                continue
            if lo == -1:
                # Single colour so far: it either stays single or gains a second colour.
                if col == hi:
                    new_lo, new_hi = -1, hi
                else:
                    new_lo, new_hi = min(col, hi), max(col, hi)
            elif col in (lo, hi):
                new_lo, new_hi = lo, hi
            else:
                continue  # third colour: this path and everything beyond it is discarded
            by_pattern[new_lo, new_hi] += 1
            if new_lo == -1:
                by_colour[-1] += 1
            else:
                by_colour[new_lo] += 1
                by_colour[new_hi] += 1
            if not fin[nxt]:
                pending.append((nxt, node, new_lo, new_hi))
    return by_pattern, by_colour

# Driver: centroid decomposition.  For each centroid r (in decomposition
# order), the paths leaving r through different incident edges are combined
# pairwise.  A combined pair is counted when the two colour sets together
# contain exactly two colours.
#
# Pairs that have r itself as an endpoint are deliberately not counted here.
# Because dfs() also reports the one-edge hop into an already-removed vertex,
# every pair whose path has an interior vertex is counted exactly once, at
# whichever interior vertex is removed first.  Adjacent vertices are joined by
# a single colour and never qualify.
rr = centroid_finder(to)
fin = [0] * n  # fin[u] == 1 once u has been processed as a centroid (read by dfs)
ans = 0
for r in rr:
    fin[r] = 1
    cnt, sub = Counter(), Counter()  # totals over the edges of r handled so far
    for v, c in to[r]:
        cr, sr = dfs(v, c, r)
        for (x, y), ways in cr.items():
            if x == -1:
                # Single colour y: partner is a two-colour path containing y,
                # or a single-colour path of any other colour.
                partners = sub[-1] - cnt[-1, y] + sub[y]
            else:
                # Colours {x, y}: partner is single-colour x, single-colour y,
                # or another path over exactly {x, y}.
                partners = cnt[-1, x] + cnt[-1, y] + cnt[x, y]
            ans += ways * partners
        cnt += cr
        sub += sr
print(ans)