結果
| 問題 | No.1424 Ultrapalindrome |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2021-03-13 13:06:37 | 
| 言語 | PyPy3  (7.3.15)  | 
                    
| 結果 | AC |
                    
| 実行時間 | 289 ms / 2,000 ms | 
| コード長 | 4,435 bytes | 
| コンパイル時間 | 159 ms | 
| コンパイル使用メモリ | 82,176 KB | 
| 実行使用メモリ | 121,372 KB | 
| 最終ジャッジ日時 | 2024-10-15 05:23:56 | 
| 合計ジャッジ時間 | 5,529 ms | 
| ジャッジサーバーID (参考情報) | judge4 / judge2 |
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 3 | 
| other | AC * 29 | 
ソースコード
from collections import deque
class LCA:
    """Lowest-common-ancestor queries on a tree using binary lifting.

    Vertices are 0-indexed and the tree is rooted at vertex 0.

    Attributes:
        depth: ``depth[v]`` is the number of edges between the root and ``v``.
        ancestor: ``ancestor[k][v]`` is the ``2**k``-th ancestor of ``v``,
            or ``-1`` when that jump would climb above the root.
    """
    __slots__ = ["depth", "ancestor"]

    def __init__(self, adj):
        size = len(adj)
        up = [-1] * size
        depth = [0] * size
        # Breadth-first search from the root; `order` doubles as the FIFO queue.
        order = [0]
        head = 0
        while head < len(order):
            cur = order[head]
            head += 1
            for nxt in adj[cur]:
                if nxt == up[cur]:
                    continue  # do not walk back to the parent
                up[nxt] = cur
                depth[nxt] = depth[cur] + 1
                order.append(nxt)
        self.depth = depth
        # Level k+1 is level k applied twice; add levels while 2**level < size.
        table = [up]
        level = 1
        while (1 << level) < size:
            prev = table[-1]
            table.append([-1 if p == -1 else prev[p] for p in prev])
            level += 1
        self.ancestor = table

    def lca(self, u, v):
        """Return the lowest common ancestor of vertices ``u`` and ``v``."""
        depth = self.depth
        if depth[u] < depth[v]:
            u, v = v, u
        # Lift the deeper vertex by the depth gap, one set bit at a time.
        gap = depth[u] - depth[v]
        level = 0
        while gap:
            if gap & 1:
                u = self.ancestor[level][u]
            gap >>= 1
            level += 1
        if u == v:
            return u
        # From the largest jump down, move both up while they stay distinct.
        for jump in self.ancestor[::-1]:
            if jump[u] != jump[v]:
                u, v = jump[u], jump[v]
        return self.ancestor[0][u]

    def dist(self, u, v):
        """Return the number of edges on the tree path between ``u`` and ``v``."""
        d = self.depth
        return d[u] + d[v] - 2 * d[self.lca(u, v)]
def slow(N, adj):
    """Brute-force reference: True iff every pair of leaves is equally far apart.

    Builds an ``LCA`` structure and compares the distance of every leaf pair
    with the distance between the first two leaves. Raises IndexError when
    the tree has fewer than two leaves.
    """
    tree = LCA(adj)
    leaves = [v for v in range(N) if len(adj[v]) == 1]
    target = tree.dist(leaves[0], leaves[1])
    count = len(leaves)
    return all(
        tree.dist(leaves[i], leaves[j]) == target
        for i in range(count)
        for j in range(i + 1, count)
    )
def dfs(N, adj, root):
    """Traverse the graph from ``root`` with an explicit stack.

    Returns ``(parent, depth)``. ``parent[v]`` is the vertex from which ``v``
    was reached, with ``parent[root] == root`` and unreached vertices left at
    -1. ``depth[v]`` is one more than its parent's depth, so on a tree it is
    the edge count from ``root``.
    """
    par = [-1] * N
    dep = [0] * N
    par[root] = root  # self-reference marks the root as already visited
    pending = [root]
    while pending:
        cur = pending.pop()
        child_depth = dep[cur] + 1
        for nxt in adj[cur]:
            if par[nxt] != -1:
                continue  # already discovered
            par[nxt] = cur
            dep[nxt] = child_depth
            pending.append(nxt)
    return par, dep
def solve(N, adj):
    """Return True iff all leaf-to-leaf distances in the tree are equal.

    Fast counterpart of ``slow``. A diameter is found with two traversals:
    first from vertex 0 to the farthest vertex, then from that vertex.
    - Odd diameter: the answer is True only when there are exactly two leaves.
    - Even diameter: step half the diameter back from the far end to a centre
      vertex. Every leaf must lie exactly half the diameter from that centre.
      The leaf-to-centre paths must not share any vertex other than the
      centre itself.
    """
    _, from_zero = dfs(N, adj, 0)
    far = max(range(N), key=from_zero.__getitem__)
    par, dep = dfs(N, adj, far)
    diameter = max(dep)
    leaves = [v for v in range(N) if len(adj[v]) == 1]
    if diameter % 2:
        return len(leaves) == 2

    half = diameter // 2
    # Walk half the diameter back from the opposite end to reach the centre.
    centre = max(range(N), key=dep.__getitem__)
    for _ in range(half):
        centre = par[centre]

    par, dep = dfs(N, adj, centre)
    seen = [False] * N
    for leaf in leaves:
        if dep[leaf] != half:
            return False
        # Mark the path to the centre; any revisit means two branches merge.
        node = leaf
        while node != centre:
            if seen[node]:
                return False
            seen[node] = True
            node = par[node]
    return True
def main():
    """Read N and N-1 one-based edges from stdin, then print Yes or No."""
    n = int(input())
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        a, b = (int(tok) - 1 for tok in input().split())
        adj[a].append(b)
        adj[b].append(a)
    print("Yes" if solve(n, adj) else "No")
import random
def prufer_decode(n, code):
    """Rebuild the edge list of the tree encoded by a Prüfer sequence.

    Vertices are 0-indexed and ``code`` must have length ``n - 2``. Returns
    ``n - 1`` edges as ``(leaf, neighbour)`` tuples in removal order; the last
    edge always joins the remaining leaf with vertex ``n - 1``. Runs in O(n)
    time because the scan pointer only ever moves forward.
    """
    assert len(code) == n - 2
    remaining = [0] * n  # occurrences of each vertex not yet consumed from code
    for v in code:
        remaining[v] += 1

    def next_free(start):
        # Smallest vertex >= start that no longer appears in the rest of code.
        while remaining[start]:
            start += 1
        return start

    pointer = next_free(0)
    leaf = pointer
    edges = []
    for v in code:
        edges.append((leaf, v))
        remaining[v] -= 1
        if remaining[v] == 0 and v < pointer:
            # v just became a leaf and is smaller than anything ahead of the pointer.
            leaf = v
        else:
            pointer = next_free(pointer + 1)
            leaf = pointer
    edges.append((leaf, n - 1))
    return edges
def tree_generator(n, *, shuffle=False):
    """Return the edges of a random labelled tree on vertices ``0..n-1``.

    Each entry of a length ``n - 2`` Prüfer sequence is drawn with
    ``random.randrange(n)``, and the sequence is decoded with
    ``prufer_decode``. The edges come back in decoding order unless
    ``shuffle=True``, which also randomises their order.

    Raises:
        ValueError: if ``n < 2``. This is an explicit check rather than
            ``assert``, so it still runs under ``python -O``. Without it,
            n=1 would silently produce a self-loop edge.
    """
    if n < 2:
        raise ValueError(f"a tree needs at least 2 vertices, got n={n}")
    code = [random.randrange(n) for _ in range(n - 2)]
    edges = prufer_decode(n, code)
    if shuffle:
        random.shuffle(edges)
    return edges
def test(t, max_n=10 ** 3):
    """Cross-check ``solve`` against the brute-force ``slow`` on random trees.

    Runs ``t`` trials, each on a random tree with between 2 and ``max_n``
    vertices inclusive. On the first disagreement it prints:
    - both answers (``slow`` first, then ``solve``);
    - the failing case in the problem's input format (N, then 1-based edges).
    It then raises AssertionError.
    """
    for _ in range(t):
        N = random.randint(2, max_n)
        edges = tree_generator(N)
        adj = [[] for _ in range(N)]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)
        # Evaluate each solver once; the report below reuses these results.
        expected = slow(N, adj)
        actual = solve(N, adj)
        if expected != actual:
            print(expected, actual)
            print(N)
            for u, v in edges:
                print(u + 1, v + 1)
            raise AssertionError(f"slow={expected} solve={actual} for N={N}")
# To cross-check solve() against the brute force, call test(10 ** 3) here
# instead of main().
if __name__ == "__main__":
    main()