結果

問題 No.1424 Ultrapalindrome
ユーザー zkouzkou
提出日時 2021-03-13 13:06:37
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 308 ms / 2,000 ms
コード長 4,435 bytes
コンパイル時間 176 ms
コンパイル使用メモリ 82,560 KB
実行使用メモリ 121,488 KB
最終ジャッジ日時 2024-04-23 05:39:21
合計ジャッジ時間 5,590 ms
ジャッジサーバーID (参考情報) judge5 / judge4
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 45 ms
56,192 KB
testcase_01 AC 45 ms
56,704 KB
testcase_02 AC 46 ms
56,192 KB
testcase_03 AC 46 ms
56,320 KB
testcase_04 AC 46 ms
56,576 KB
testcase_05 AC 45 ms
56,704 KB
testcase_06 AC 46 ms
56,192 KB
testcase_07 AC 46 ms
56,448 KB
testcase_08 AC 46 ms
56,320 KB
testcase_09 AC 245 ms
92,636 KB
testcase_10 AC 253 ms
92,544 KB
testcase_11 AC 177 ms
86,912 KB
testcase_12 AC 200 ms
90,088 KB
testcase_13 AC 117 ms
80,000 KB
testcase_14 AC 188 ms
89,056 KB
testcase_15 AC 52 ms
63,232 KB
testcase_16 AC 157 ms
84,608 KB
testcase_17 AC 173 ms
87,168 KB
testcase_18 AC 160 ms
83,796 KB
testcase_19 AC 205 ms
87,444 KB
testcase_20 AC 112 ms
78,720 KB
testcase_21 AC 170 ms
85,312 KB
testcase_22 AC 137 ms
82,732 KB
testcase_23 AC 93 ms
78,588 KB
testcase_24 AC 107 ms
78,720 KB
testcase_25 AC 106 ms
78,592 KB
testcase_26 AC 217 ms
89,132 KB
testcase_27 AC 308 ms
100,600 KB
testcase_28 AC 156 ms
91,676 KB
testcase_29 AC 169 ms
94,208 KB
testcase_30 AC 192 ms
121,488 KB
testcase_31 AC 180 ms
112,944 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from collections import deque
from itertools import combinations


class LCA:
    """Lowest common ancestor / distance queries on a tree (binary lifting).

    Vertices are 0-indexed and the tree is rooted at vertex 0.

    Attributes:
        depth: depth[v] is the number of edges between vertex 0 and v.
        ancestor: ancestor[k][v] is the 2**k-th ancestor of v, or -1 when
            v has fewer than 2**k ancestors.
    """

    __slots__ = ["depth", "ancestor"]

    def __init__(self, adj):
        """Build the depth array and lifting table from adjacency lists.

        Assumes ``adj`` describes a tree: no visited set is kept, only the
        edge leading back to the parent is skipped.
        """
        size = len(adj)
        up = [-1] * size
        level = [0] * size
        self.depth = level

        frontier = deque([0])
        while frontier:
            cur = frontier.popleft()
            for nxt in adj[cur]:
                if nxt == up[cur]:
                    continue  # do not walk back to the parent
                up[nxt] = cur
                level[nxt] = level[cur] + 1
                frontier.append(nxt)

        # One row per power of two; rows stop once 2**k >= size, which
        # already exceeds any possible depth gap (at most size - 1).
        table = [up]
        k = 1
        while (1 << k) < size:
            prev = table[-1]
            table.append([-1 if p == -1 else prev[p] for p in prev])
            k += 1
        self.ancestor = table

    def lca(self, u, v):
        """Return the lowest common ancestor of vertices u and v."""
        dep = self.depth
        if dep[u] < dep[v]:
            u, v = v, u
        # Lift the deeper vertex by the depth gap, one set bit at a time.
        gap = dep[u] - dep[v]
        k = 0
        while gap:
            if gap & 1:
                u = self.ancestor[k][u]
            gap >>= 1
            k += 1
        if u == v:
            return u
        # Jump both vertices up while they stay distinct; afterwards both
        # are children of the common ancestor.
        for row in reversed(self.ancestor):
            if row[u] != row[v]:
                u, v = row[u], row[v]
        return self.ancestor[0][u]

    def dist(self, u, v):
        """Return the number of edges on the path between u and v."""
        return self.depth[u] + self.depth[v] - 2 * self.depth[self.lca(u, v)]


def slow(N, adj):
    """Brute-force reference: True iff all leaf-to-leaf distances are equal.

    Compares every pair of leaves using LCA distances, so the work grows
    with the square of the number of leaves; intended only for
    cross-checking solve().  A tree with fewer than two leaves (a single
    vertex) is accepted vacuously, which matches what solve() returns.
    """
    lca = LCA(adj)
    leaves = [v for v in range(N) if len(adj[v]) == 1]
    if len(leaves) < 2:
        # No pair of leaves exists, so the condition holds trivially
        # (previously leaves[1] raised IndexError here).
        return True
    target = lca.dist(leaves[0], leaves[1])
    return all(lca.dist(a, b) == target for a, b in combinations(leaves, 2))


def dfs(N, adj, root):
    """Traverse the tree from ``root`` with an explicit stack.

    Returns:
        (parent, depth) where parent[v] is v's neighbour towards ``root``
        (parent[root] == root, which doubles as its "visited" mark) and
        depth[v] is the number of edges from ``root`` to v.  Vertices not
        reached keep parent -1 and depth 0.
    """
    par = [-1] * N
    dep = [0] * N
    par[root] = root  # self-loop marks the root as already visited
    todo = [root]
    while todo:
        cur = todo.pop()
        child_depth = dep[cur] + 1
        for nxt in adj[cur]:
            if par[nxt] != -1:
                continue  # already visited
            par[nxt] = cur
            dep[nxt] = child_depth
            todo.append(nxt)
    return par, dep


def solve(N, adj):
    """Return True iff every pair of leaves is the same distance apart.

    The tree has N vertices (0-indexed) given as adjacency lists ``adj``.

    Approach: find a diameter of length D with two DFS passes.
    * D odd: the answer is yes only for a path (exactly two leaves).  Three
      leaves meeting at a vertex with leg lengths x, y, z would need
      x + y = y + z = z + x, i.e. D = 2x, which is even.
    * D even: every leaf must be exactly D / 2 from the midpoint of the
      diameter.  The leaf-to-midpoint paths must also be vertex-disjoint
      apart from the midpoint itself; otherwise two leaves meet earlier and
      are closer than D.
    """
    _, depth_0 = dfs(N, adj, 0)
    # The vertex farthest from any start vertex is an endpoint of a diameter.
    root = max(range(N), key=depth_0.__getitem__)
    parent, depth = dfs(N, adj, root)
    end = max(range(N), key=depth.__getitem__)
    diameter = depth[end]
    leaves = [v for v in range(N) if len(adj[v]) == 1]

    if diameter % 2 == 1:
        return len(leaves) == 2

    half = diameter // 2
    # Walk half-way back from `end` along the diameter to its midpoint.
    center = end
    for _ in range(half):
        center = parent[center]

    parent, dist_from_center = dfs(N, adj, center)
    used = [False] * N
    for leaf in leaves:
        if dist_from_center[leaf] != half:
            return False
        v = leaf
        while v != center:
            if used[v]:
                # Two leaves share a vertex before the center: too close.
                return False
            used[v] = True
            v = parent[v]
    return True


def main():
    """Read the tree from standard input and print the verdict.

    Input: N on the first line, then N - 1 lines "u v" with 1-indexed edge
    endpoints.  Prints "Yes" when solve() accepts the tree, else "No".
    """
    N = int(input())
    adj = [[] for _ in range(N)]
    for _ in range(N - 1):
        a, b = (int(token) - 1 for token in input().split())
        adj[a].append(b)
        adj[b].append(a)
    accepted = solve(N, adj)
    print("Yes" if accepted else "No")


import random

def prufer_decode(n, code):
    """Decode a 0-indexed Pruefer sequence into a tree's edge list in O(n).

    Args:
        n: number of vertices (labels 0..n-1).
        code: Pruefer sequence of length n - 2; entries are expected to lie
            in range(n).

    Returns:
        List of the n - 1 edges as (u, v) tuples.
    """
    assert len(code) == n - 2
    # remaining[v]: how many times v still occurs in the unprocessed code.
    remaining = [0] * n
    for x in code:
        remaining[x] += 1

    def next_free(start):
        # Smallest vertex >= start that no longer occurs in the rest of code.
        while remaining[start]:
            start += 1
        return start

    ptr = next_free(0)
    cur = ptr
    result = []
    for x in code:
        result.append((cur, x))
        remaining[x] -= 1
        if remaining[x] == 0 and x < ptr:
            # x just became a leaf and is smaller than the scan pointer,
            # so it is the next leaf to remove.
            cur = x
        else:
            ptr = next_free(ptr + 1)
            cur = ptr
    result.append((cur, n - 1))
    return result

def tree_generator(n, shuffle=False):
    """Return the edge list of a random labelled tree on vertices 0..n-1.

    A Pruefer sequence with independently uniform entries is decoded with
    prufer_decode, so every labelled tree on n vertices is equally likely.
    Edges come out in decoding order.  Pass shuffle=True to also randomize
    the order of the edge list; this replaces the former commented-out
    ``random.shuffle`` line.
    """
    assert n >= 2
    code = [random.randrange(n) for _ in range(n - 2)]
    edges = prufer_decode(n, code)
    if shuffle:
        random.shuffle(edges)
    return edges


def test(t, max_n=10 ** 3):
    """Cross-check solve() against the brute-force slow() on t random trees.

    Each tree comes from tree_generator with a size drawn by
    random.randint(2, max_n).  On the first disagreement, prints both
    answers (slow first), then the tree in the problem's 1-indexed input
    format, and raises AssertionError.
    """
    for _ in range(t):
        N = random.randint(2, max_n)
        edges = tree_generator(N)
        adj = [[] for _ in range(N)]
        for u, v in edges:
            adj[u].append(v)
            adj[v].append(u)
        # Evaluate each answer once and reuse it for the report.
        expected = slow(N, adj)
        actual = solve(N, adj)
        if expected != actual:
            print(expected, actual)
            print(N)
            for u, v in edges:
                print(u + 1, v + 1)
            raise AssertionError("solve() disagrees with slow()")

# To cross-check solve() against slow(), call test(10 ** 3) instead of main().
if __name__ == "__main__":
    main()