結果
| 問題 |
No.3346 Tree to DAG
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2025-10-29 17:47:48 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
AC
|
| 実行時間 | 672 ms / 2,000 ms |
| コード長 | 5,272 bytes |
| コンパイル時間 | 285 ms |
| コンパイル使用メモリ | 82,728 KB |
| 実行使用メモリ | 130,696 KB |
| 最終ジャッジ日時 | 2025-11-13 21:06:00 |
| 合計ジャッジ時間 | 15,635 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 39 |
ソースコード
class TopK:
def __init__(self,K,e):
self.array = [e]*K
self._index = 0
self.K = K
self.e = e
def add(self,x):
for i in range(self.K):
if x >= self.array[i]:
self.array = self.array[:i] + [x] + self.array[i:-1]
break
def __str__(self):
return str(self.array)
from collections import Counter
def normalize_bits(bits):
"""
ビット(指数)のリストを受け取り、繰り上がり(2^k + 2^k = 2^{k+1})
を全て処理して正規化された(重複のない)ビットリストを返す。
例: [5, 5, 3] -> [6, 3]
[5, 5, 6] -> [7]
"""
if not bits:
return []
bits_sorted = sorted(bits,reverse=True)
stack = []
for i in range(len(bits_sorted)):
b = bits_sorted[i]
while len(stack) > 0 and stack[-1] == b:
b = stack.pop()+1
stack.append(b)
#print(stack)
return stack
def compare_fractions(abc, def_):
"""
非負整数の3つ組 (a,b,c) と (d,e,f) を受け取り、
F(a,b,c) = (2^{a+1}+2^{b+1}+2^{c+1}-3) / 2^{a+b+c}
F(d,e,f) = (2^{d+1}+2^{e+1}+2^{f+1}-3) / 2^{d+e+f}
の大小を比較する。
戻り値:
- 1 : F(a,b,c) > F(d,e,f)
- 0 : F(a,b,c) == F(d,e,f)
- -1 : F(a,b,c) < F(d,e,f)
"""
a, b, c = abc
d, e, f = def_
A = a + b + c
D = d + e + f
# 方針2で導出した両辺の指数リストを作成
lhs_bits = [
a + 1 + D,
b + 1 + D,
c + 1 + D,
A + 1,
A
]
rhs_bits = [
d + 1 + A,
e + 1 + A,
f + 1 + A,
D + 1,
D
]
# 方針3: ビットリストを正規化
norm_lhs = normalize_bits(lhs_bits)
norm_rhs = normalize_bits(rhs_bits)
# 方針4: 正規化済みリストを辞書順に比較
if norm_lhs > norm_rhs:
return 1
elif norm_lhs < norm_rhs:
return -1
else:
return 0
# ===================================
from collections import deque
MOD = 998244353
N = int(input())
connect = [[] for _ in range(N)]
for _ in range(N-1):
u,v = map(lambda x:int(x)-1,input().split())
connect[u].append(v)
connect[v].append(u)
top3_array = [TopK(3,0) for _ in range(N)] # ある頂点から見た部分木の中で深さtop3のやつ
# 子から親への頂点を削除 O(M)
parents = [-1] * N
tps = []
Q = deque([0])
while Q:
i = Q.popleft()
tps.append(i)
for a in connect[i]:
if a != parents[i]:
parents[a] = i
connect[a].remove(i)
Q.append(a)
#print(connect)
#print(parents)
#print(tps)
# Bottom-Up DP
acc_BU = [0] * N
res_BU = [0] * N
for i in tps[1:][::-1]:
res_BU[i] = acc_BU[i] + 1
p = parents[i]
acc_BU[p] = max(acc_BU[p], res_BU[i])
top3_array[p].add(res_BU[i])
res_BU[tps[0]] = acc_BU[tps[0]] + 1
#print(res_BU)
#print([x.array for x in top3_array])
#Top-Down DP
acc_TD = [0] * N
res_TD = [0] * N
res = [0]*N
# 1. 初期化の修正
res_TD[0] = 0 # 根の「親側」の深さは0
res[0] = max(res_BU[0], res_TD[0]) # 根の最大深さ
top3_array[0].add(res_TD[0]) # 根のtop3にも親側(0)を追加
for i in tps:
# 2. array の構築を修正
# 子のBU深さリスト + 親側(TD)の深さ
array = [res_BU[c] for c in connect[i]]
array.append(res_TD[i]) # res[parents[i]] ではなく res_TD[i]
# print("Debug",i,parents[i],array) # デバッグ用
L = len(array)
# 右向き累積max (左からi-1番目までのmax)
accr_cum = [-float("inf")]*(len(array)+1)
for j,v in enumerate(array):
accr_cum[j+1] = max(accr_cum[j],v)
# 左向き累積max (右からi+1番目までのmax)
accl_cum = [-float("inf")]*(len(array)+1)
for j,v in enumerate(reversed(array)):
accl_cum[L-j-1] = max(v,accl_cum[L-j])
# print(accr_cum, accl_cum) # デバッグ用
# 3. acc_TD, res_TD の更新ロジックを修正
for j,c in enumerate(connect[i]):
# c 以外の方向の最大深さを計算
# (arrayのj番目が c のBU深さ だった)
acc_TD_for_c = max(accr_cum[j], accl_cum[j+1])
# c にとっての「親側(i)からの深さ」は acc_TD_for_c + 1
res_TD[c] = acc_TD_for_c + 1
# c の最大深さ(偏心度)を更新
res[c] = max(res_BU[c], res_TD[c])
# c のtop3に「親側からの深さ」を追加
top3_array[c].add(res_TD[c])
# print("add",j,c,res[c],top3_array[c]) # デバッグ用
# print(res_TD) # デバッグ用
# resは既に計算済みなので、最後の行は不要 (あっても良いが冗長)
# res = [max(res_BU[i],res_TD[i]) for i in range(N)]
#print(res)
#print([x.array for x in top3_array])
def f(t):
a,b,c = t
K = a+b+c
return 2**(N+2) - 2**(N-K)*(2**(a+1) + 2**(b+1) + 2**(c+1) - 3)
M = [0,0,0]
for i in range(N):
x = top3_array[i].array
#print(f(x),f(M),compare_fractions(x,M))
if compare_fractions(M,x) == 1:
M = x
a,b,c = M
K = sum(M)
#print(a,b,c)
ans = pow(2,N+2,MOD) - pow(2,N-K,MOD) * (pow(2,a+1,MOD) + pow(2,b+1,MOD) + pow(2,c+1,MOD) - 3)
ans %= MOD
print(ans)