結果

問題 No.3309 Aging Railway
コンテスト
ユーザー hato336
提出日時 2025-10-24 21:46:25
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,020 bytes
コンパイル時間 295 ms
コンパイル使用メモリ 82,312 KB
実行使用メモリ 81,080 KB
最終ジャッジ日時 2025-10-24 21:46:43
合計ジャッジ時間 11,873 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 3 WA * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

import math,collections,sys
class lca_doubling():
    def __init__(self,n,edge):
        self.n = n
        self.edge = edge
        self.dist = [10**20] * n
        self.k = math.floor(math.log2(n))
        self.ansestor = [[-1 for i in range(n)] for j in range(self.k + 1)]
        self.bfs()
        self.p = [1<<60 for i in range(n)]
        self.p2 = [[1<<60 for i in range(n)] for j in range(self.k + 1)]
        for i in range(self.k):
            for j in range(n):
               if self.ansestor[i][j] == -1:
                   self.ansestor[i+1][j] = -1
               else:
                   self.ansestor[i+1][j] = self.ansestor[i][self.ansestor[i][j]]
       
    def bfs(self):
        d = collections.deque()
        d.append(0)
        mou = set()
        mou.add(0)
        self.dist[0] = 0
        while d:
            now = d.popleft()
            for i in self.edge[now]:
                if i not in mou:
                    self.ansestor[0][i] = now
                    d.append(i)
                    mou.add(i)
                    self.dist[i] = self.dist[now] + 1
    
    def ans(self,u,a,w):
        for i in reversed(range(self.k + 1)):
            if u == -1:
                break
            if (a >> i) & 1:
                w = min(w,self.p2[i][u])
                u = self.ansestor[i][u]
        return u,w
    
    def p2calc(self):
        for i in range(self.k):
            for j in range(n):
               if self.ansestor[i][j] == -1:
                   self.p2[i+1][j] = self.p2[i][j]
               else:
                   self.p2[i+1][j] = min(self.p2[i][self.ansestor[i][j]], self.p2[i][j])
    
    def calc(self,s,t):
        sw = 1<<60
        tw = 1<<60
        if self.dist[s] > self.dist[t]:
            s,sw = self.ans(s,self.dist[s] - self.dist[t],sw)
        if self.dist[s] < self.dist[t]:
            t,tw = self.ans(t,-self.dist[s] + self.dist[t],tw)
        for i in reversed(range(self.k)):
            ns = self.ansestor[i][s]
            nt = self.ansestor[i][t]
            if ns != nt:
                sw = min(sw,self.p2[i][s])
                tw = min(tw,self.p2[i][t])
                s = ns
                t = nt
        if self.ansestor[0][s] != -1 and s != t:
            sw = min(sw,self.p2[0][s])

        x = (0 if self.ansestor[0][s] == -1 else s if s == t else self.ansestor[0][s])
        return x,min(sw,tw)
    
input = sys.stdin.readline
n,m = map(int,input().split())
edge = [[] for i in range(n)]
e = []
for i in range(n-1):
    u,v = map(int,input().split())
    u -= 1
    v -= 1
    edge[u].append(v)
    edge[v].append(u)
    e.append((u,v))

l = lca_doubling(n,edge)

for i in range(n-1):
    u,v = e[i]
    if l.dist[u] > l.dist[v]:
        u,v = v,u
    l.p[v] = i
    l.p2[0][v] = i

l.p2calc()

ans = [0 for i in range(n)]
for i in range(m):
    u,v = map(int,input().split())
    u -= 1
    v -= 1
    w,r = l.calc(u,v)

    ans[0] += 1
    ans[r] -= 1
for i in range(n-1):
    ans[i+1] += ans[i]
ans.pop()
print(*ans,sep='\n')
0