結果
| 問題 | No.1718 Random Squirrel | 
| コンテスト | |
| ユーザー |  | 
| 提出日時 | 2022-03-16 19:36:32 | 
| 言語 | PyPy3 (7.3.15) | 
| 結果 | AC |
| 実行時間 | 448 ms / 2,000 ms | 
| コード長 | 2,644 bytes | 
| コンパイル時間 | 266 ms | 
| コンパイル使用メモリ | 82,192 KB | 
| 実行使用メモリ | 120,976 KB | 
| 最終ジャッジ日時 | 2024-09-24 20:18:29 | 
| 合計ジャッジ時間 | 8,826 ms | 
| ジャッジサーバーID (参考情報) | judge2 / judge1 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 2 | 
| other | AC * 31 | 
ソースコード
def solve(n, edges, targets):
    """Compute, for every start vertex, the cost of visiting all marked vertices.

    The tree has vertices ``1..n``.  For a start vertex ``v`` the result is
    ``2 * E(v) - F(v)`` where ``E(v)`` is the number of edges lying on a path
    from ``v`` to some marked vertex and ``F(v)`` is the distance from ``v`` to
    the farthest marked vertex.  This is the length of the shortest walk that
    starts at ``v`` and visits every marked vertex without having to return:
    every needed edge is walked twice except those on the final, longest leg.

    Args:
        n: Number of vertices.
        edges: Iterable of ``(a, b)`` pairs, the ``n - 1`` tree edges.
        targets: Iterable of marked vertex labels.

    Returns:
        A list of ``n`` integers; element ``i`` is the answer for vertex
        ``i + 1``.
    """
    adj = [[] for _ in range(n + 1)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    is_target = [False] * (n + 1)
    for t in targets:
        is_target[t] = True

    # Iterative DFS from vertex 1 (avoids recursion limits).  ``order`` lists
    # every parent before its children; label 0 is the root's dummy parent
    # and never collides with a real vertex label.
    parent = [0] * (n + 1)
    order = []
    stack = [1]
    while stack:
        v = stack.pop()
        order.append(v)
        for w in adj[v]:
            if w != parent[v]:
                parent[w] = v
                stack.append(w)

    # Bottom-up pass.
    #   down_far[v]  : distance from v to the farthest marked vertex strictly
    #                  below v (0 if there is none).
    #   down_cost[v] : twice the number of edges below v that lead to a
    #                  marked vertex.
    #   via_far[c] / via_cost[c] : the same two quantities as seen from c's
    #                  parent when looking only into c's subtree, i.e. they
    #                  include the edge parent-c; both are 0 when c's subtree
    #                  holds no marked vertex.
    down_far = [0] * (n + 1)
    down_cost = [0] * (n + 1)
    via_far = [0] * (n + 1)
    via_cost = [0] * (n + 1)
    for v in reversed(order):
        p = parent[v]
        if p and (down_far[v] or is_target[v]):
            via_far[v] = down_far[v] + 1
            via_cost[v] = down_cost[v] + 2
            if via_far[v] > down_far[p]:
                down_far[p] = via_far[v]
            down_cost[p] += via_cost[v]

    # Top-down (rerooting) pass.
    #   up_far[v] / up_cost[v] : farthest distance and doubled edge count for
    #   marked vertices outside v's subtree, as seen from v (0 if none).
    up_far = [0] * (n + 1)
    up_cost = [0] * (n + 1)
    answers = [0] * (n + 1)
    for v in order:
        answers[v] = down_cost[v] + up_cost[v] - max(down_far[v], up_far[v])

        # The two largest child contributions let us obtain "max over all
        # children except c" in O(1) per child.
        top1 = top2 = 0
        for c in adj[v]:
            if c == parent[v]:
                continue
            reach = via_far[c]
            if reach > top1:
                top1, top2 = reach, top1
            elif reach > top2:
                top2 = reach

        for c in adj[v]:
            if c == parent[v]:
                continue
            siblings_far = top2 if via_far[c] == top1 else top1
            far = max(siblings_far, up_far[v])
            if far or is_target[v]:
                up_far[c] = far + 1
                # Everything v sees except c's own branch, plus edge v-c.
                up_cost[c] = up_cost[v] + down_cost[v] - via_cost[c] + 2

    return answers[1:]


def main():
    """Read ``N K``, ``N - 1`` edges and ``K`` marked vertices; print answers."""
    import sys

    tokens = sys.stdin.read().split()
    n, k = int(tokens[0]), int(tokens[1])
    pos = 2
    edges = []
    for _ in range(n - 1):
        edges.append((int(tokens[pos]), int(tokens[pos + 1])))
        pos += 2
    targets = [int(tok) for tok in tokens[pos:pos + k]]
    # One joined write instead of one print call per vertex.
    print("\n".join(map(str, solve(n, edges, targets))))


if __name__ == "__main__":
    main()
            
            
            
        