結果

問題 No.3346 Tree to DAG
コンテスト
ユーザー Nzt3
提出日時 2025-11-09 17:25:34
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 298 ms / 2,000 ms
コード長 3,242 bytes
コンパイル時間 265 ms
コンパイル使用メモリ 82,664 KB
実行使用メモリ 110,392 KB
最終ジャッジ日時 2025-11-13 21:12:52
合計ジャッジ時間 7,471 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 39
権限があれば一括ダウンロードができます

ソースコード

diff #

class ModInt:
    """Integer modulo the prime ``MOD`` (998244353).

    Operands of the arithmetic operators may be ``ModInt`` or plain ``int``,
    on either side. Every operation returns a new ``ModInt`` whose value
    ``v`` is normalized into ``[0, MOD)``.
    """

    MOD = 998244353

    def __init__(self, value):
        self.v = value % self.MOD

    @staticmethod
    def _raw(other):
        """Return the plain integer behind a ``ModInt`` or ``int`` operand."""
        return other.v if isinstance(other, ModInt) else other

    def __add__(self, other):
        return ModInt(self.v + self._raw(other))

    __radd__ = __add__

    def __sub__(self, other):
        return ModInt(self.v - self._raw(other))

    def __rsub__(self, other):
        # Reached only for ``int - ModInt``.
        return ModInt(self._raw(other) - self.v)

    def __mul__(self, other):
        return ModInt(self.v * self._raw(other))

    __rmul__ = __mul__

    def __pow__(self, exp):
        return ModInt(pow(self.v, exp, self.MOD))

    def __truediv__(self, other):
        """Multiply by the modular inverse (Fermat's little theorem).

        Raises:
            ZeroDivisionError: if the divisor is a multiple of ``MOD``. The
                inverse does not exist there; ``pow(0, MOD - 2, MOD)`` would
                silently give 0.
        """
        divisor = self._raw(other) % self.MOD
        if divisor == 0:
            raise ZeroDivisionError("ModInt division by a multiple of the modulus")
        return ModInt(self.v * pow(divisor, self.MOD - 2, self.MOD))

    def __rtruediv__(self, other):
        # Reached only for ``int / ModInt``.
        return ModInt(other) / self

    def __int__(self):
        return self.v

    def __repr__(self):
        return str(self.v)


def greater_check(a, b):
    # a,b: list of 3 ints
    abit = [0] * 5
    bbit = [0] * 5
    ak = sum(a)
    bk = sum(b)
    for i in range(3):
        abit[i] = a[i] + 1 + bk
        bbit[i] = b[i] + 1 + ak
    abit[3], abit[4] = ak + 1, ak
    bbit[3], bbit[4] = bk + 1, bk

    def bitsort(bit):
        for _ in range(5):
            for j in range(1, 5):
                if bit[j] != -1 and bit[j] == bit[j - 1]:
                    bit[j - 1] = bit[j] + 1
                    bit[j] = -1
                elif bit[j] > bit[j - 1]:
                    bit[j], bit[j - 1] = bit[j - 1], bit[j]

    bitsort(abit)
    bitsort(bbit)

    for i in range(5):
        if abit[i] > bbit[i]:
            return True
        if abit[i] < bbit[i]:
            return False
    return False


def bfs(G, N, start):
    """Breadth-first search from one or more sources.

    Args:
        G: adjacency lists, ``G[v]`` iterates over the neighbours of ``v``.
        N: number of vertices.
        start: iterable of source vertices; each becomes its own parent.

    Returns:
        ``(order, parent)``. ``order`` holds the vertices in visit order,
        with the sources first in the order given. ``parent[v]`` is the
        BFS-tree parent of ``v``, ``v`` itself for a source, or -1 if ``v``
        was never reached.
    """
    parent = [-1] * N
    order = list(start)
    for root in order:
        parent[root] = root

    head = 0
    while head < len(order):
        current = order[head]
        head += 1
        for nxt in G[current]:
            if parent[nxt] != -1:
                continue  # already discovered
            parent[nxt] = current
            order.append(nxt)
    return order, parent


def main():
    """Read a tree from stdin and print the result modulo 998244353.

    Input format (as parsed below): N, then N - 1 edges "U V" with 1-based
    vertex labels.

    Outline:
      1. Double-sweep BFS (from vertex 0, then from the last vertex that BFS
         reached) and walk parent links to obtain the path ``diameter``.
      2. Multi-source BFS from every diameter vertex. For a diameter vertex v,
         ``depth[v]`` becomes the height of the BFS-tree subtree rooted at v.
      3. For each diameter vertex, form the sorted triple (index along the
         path, distance to the other end, depth). Keep a candidate whenever
         ``greater_check(best, candidate)`` is True.
      4. Print 2**(N+2) - 2**(N-K-1) * (sum(2**(d+2) for d in best) - 6)
         modulo 998244353, where K = sum(best).
    """
    import sys

    # Bulk read is much faster than per-line readline for large N.
    data = sys.stdin.buffer.read().split()
    N = int(data[0])
    G = [[] for _ in range(N)]
    # The N - 1 edges occupy tokens 1 .. 2N-2; convert labels to 0-based.
    nums = [int(x) - 1 for x in data[1:2 * N - 1]]
    for u, v in zip(nums[::2], nums[1::2]):
        G[u].append(v)
        G[v].append(u)

    # 1. Double sweep: the last vertex in BFS order is a farthest vertex.
    order, _ = bfs(G, N, [0])
    end_a = order[-1]
    order, par = bfs(G, N, [end_a])
    end_b = order[-1]

    # Follow parent links from end_b back to the root end_a.
    diameter = [end_b]
    while par[diameter[-1]] != diameter[-1]:
        diameter.append(par[diameter[-1]])

    # 2. Roots (diameter vertices) come first in BFS order, and every parent
    # precedes its children. A reverse scan therefore finalizes each child
    # before its parent and can stop at the first root it meets.
    order, par = bfs(G, N, diameter)
    depth = [0] * N
    for v in reversed(order):
        p = par[v]
        if p == v:
            break  # only roots remain
        depth[p] = max(depth[p], depth[v] + 1)

    # 3. Pick the best sorted triple along the diameter path.
    last = len(diameter) - 1
    best = [0, 0, 0]
    for t, v in enumerate(diameter):
        cand = sorted((t, last - t, depth[v]))
        if greater_check(best, cand):
            best = cand

    # 4. Evaluate the final formula modulo 998244353.
    K = sum(best)
    ans = ModInt(2) ** (N + 2)
    ans -= (ModInt(2) ** (N - K - 1)) * (
        (ModInt(2) ** (best[0] + 2))
        + (ModInt(2) ** (best[1] + 2))
        + (ModInt(2) ** (best[2] + 2))
        - 6
    )
    print(int(ans))


# Run the solver only when executed as a script, not on import.
if __name__ == "__main__":
    main()