結果

問題 No.3206 う し た ウ ニ 木 あ く ん 笑
ユーザー Koi
提出日時 2025-07-18 22:29:20
言語 PyPy3
(7.3.15)
結果
RE  
実行時間 -
コード長 2,093 bytes
コンパイル時間 226 ms
コンパイル使用メモリ 82,356 KB
実行使用メモリ 68,088 KB
最終ジャッジ日時 2025-07-18 22:29:24
合計ジャッジ時間 3,664 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample RE * 2
other RE * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

from sortedcontainers import SortedList
N = int(input())
Graph = [[] for _ in range(N)]
for _ in range(N - 1):
    a, b = map(int, input().split())
    a -= 1
    b -= 1
    Graph[a].append(b)
    Graph[b].append(a)

done = [0] * N
Q = [~0, 0] # 根をスタックに追加
ET = []
dp = [0] * N
L = [SortedList([]) for _ in range(N)]
Parent = [-1 for _ in range(N)]
Children = [[] for _ in range(N)]
while Q:
    i = Q.pop()
    if i >= 0: # 行きがけの処理
        done[i] = 1
        ET.append(i)
        for a in Graph[i][::-1]:
            if done[a]: continue
            Parent[a] = i
            Children[i].append(a)
            Q.append(~a) # 帰りがけの処理をスタックに追加
            Q.append(a) # 行きがけの処理をスタックに追加
    
    else: # 帰りがけの処理
        i = ~i
        ET.append(i)
        for x in Children[i]:
            L[i].add(dp[x])
        if(len(L[i]) > 0):
            dp[i] = L[i][-1] + 1
        else:
            dp[i] = 1
        # if(len(L[i]) >= 3):
        #     dp[i] = L[i][-1] + L[i][-2] + L[i][-3] + 1
        # else:
        #     dp[i] = 0
# print(ET)
# print(L)
# print(dp)
Seen = [False] * N
s = 0
ans = 0
# for x in Graph[0]:
#     ans = max(ans, dp[x])
for i in range(len(L[0])):
    ans = max(ans, L[0][i] * (len(L[0]) - i) + 1)
# print(ans)
for i in range(1, len(ET) - 1):
    x = ET[i]
    if(not Seen[x]):
        #pre -> x
        Seen[x] = True
        t = x
    else:
        #x -> Parent[x]
        t = Parent[x]
    #(sからtに交換)
    L[s].remove(dp[t])
    # if(len(L[s]) >= 3):
    #     dp[s] = L[s][-1] + L[s][-2] + L[s][-3] + 1
    # else:
    #     dp[s] = 0
    if(len(L[s]) > 0):
        dp[s] = L[s][-1] + 1
    else:
        dp[s] = 1
    L[t].add(dp[s])
    if(len(L[t]) > 0):
        dp[t] = L[t][-1] + 1
    else:
        dp[t] = 1
    # print(s, t)
    # print(L)
    # print(dp)
    # print(L[t])
    # print()
    for i in range(len(L[t])):
        ans = max(ans, L[t][i] * (len(L[t]) - i) + 1)
    s = t
if(ans == 0):
    print(-1)
else:
    print(ans)
# print(ans)
0