結果
| 問題 |
No.3206 う し た ウ ニ 木 あ く ん 笑
|
| コンテスト | |
| ユーザー |
もの
|
| 提出日時 | 2025-07-18 22:43:21 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 6,487 bytes |
| コンパイル時間 | 403 ms |
| コンパイル使用メモリ | 82,612 KB |
| 実行使用メモリ | 452,244 KB |
| 最終ジャッジ日時 | 2025-07-18 22:43:45 |
| 合計ジャッジ時間 | 21,292 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 22 WA * 8 |
ソースコード
############################################################################################
def factorization(n):
arr = []
temp = n
for i in range(2, int(-(-n**0.5//1))+1):
if temp%i==0:
cnt=0
while temp%i==0:
cnt+=1
temp //= i
arr.append([i, cnt])
if temp!=1:
arr.append([temp, 1])
if arr==[]:
arr.append([n, 1])
return arr
def cycle_detection(f, x0):
power = lam = 1
first = x0
second = f(x0)
while first != second:
if power == lam:
first = second
power *= 2
lam = 0
second = f(second)
lam += 1
first = second = x0
for i in range(lam):
second = f(second)
mu = 0
while first != second:
first = f(first)
second = f(second)
mu += 1
return mu, lam
def cycle_detection_lists(f, x0):
mu, lam = cycle_detection(f, x0)
before_cycle = [0] * mu
if mu > 0:
before_cycle[0] = x0
for i in range(mu - 1):
before_cycle[i + 1] = f(before_cycle[i])
cycle = [0] * lam
cycle[0] = before_cycle[-1] if mu > 0 else x0
for i in range(lam - 1):
cycle[i + 1] = f(cycle[i])
return before_cycle, cycle
import bisect,collections,copy,heapq,itertools,math,string,sys,queue,time,random
from decimal import Decimal
def I(): return input()
def IS(): return input().split()
def II(): return int(input())
def IIS(): return list(map(int,input().split()))
def LIIS(): return list(map(int,input().split()))
def comb(n, r):return math.factorial(n) // (math.factorial(n - r) * math.factorial(r))
def make_divisors(n):
lower_divisors , upper_divisors = [], []
i = 1
while i*i <= n:
if n % i == 0:
lower_divisors.append(i)
if i != n // i:
upper_divisors.append(n//i)
i += 1
return lower_divisors + upper_divisors[::-1]
INF=float("inf")
MOD=998244353
MOD2=10**9+7
sys.setrecursionlimit(2*10**5)
alpha="ABCDEFGHIJKLMNOPQRSTUVWXYZ"
def bit_count(x):
return bin(x).count("1")
def yesno(f):
if f:print("Yes")
else:print("No")
# import pypyjit
# pypyjit.set_param('max_unroll_recursion=-1')
def prime_factorization(n):
factors = []
temp = n
for i in range(2, int(-(-n**0.5//1)) + 1):
if temp % i == 0:
count = 0
while temp % i == 0:
count += 1
temp //= i
factors.append((i, count))
if temp != 1:
factors.append((temp, 1))
if not factors:
factors.append((n, 1))
return factors
#n進数表示
def to_base(n, base):
if n == 0:
return [0]
digits = []
while n:
digits.append(n % base)
n //= base
return digits[::-1]
class ncr_object:
def __init__(self, n, mod=MOD):
self.n = n
self.mod = mod
self.fact = [1] * (n + 1)
for i in range(2, n + 1):
self.fact[i] = self.fact[i - 1] * i % mod
def comb(self, n, r):
if n < r or n < 0 or r < 0:
return 0
return self.fact[n] * pow(self.fact[r], self.mod - 2, self.mod) * pow(self.fact[n - r], self.mod - 2, self.mod) % self.mod
def _ntt(a, invert, mod, root):
n = len(a)
j = 0
for i in range(1, n):
bit = n >> 1
while j & bit:
j ^= bit
bit >>= 1
j |= bit
if i < j:
a[i], a[j] = a[j], a[i]
length = 2
while length <= n:
# wlen = root^((mod-1)/length) mod mod
exp = (mod - 1) // length
wlen = pow(root, exp, mod)
if invert:
wlen = pow(wlen, mod - 2, mod)
for i in range(0, n, length):
w = 1
for k in range(i, i + length // 2):
u = a[k]
v = a[k + length // 2] * w % mod
a[k] = (u + v) % mod
a[k + length // 2] = (u - v + mod) % mod
w = w * wlen % mod
length <<= 1
if invert:
inv_n = pow(n, mod - 2, mod)
for i in range(n):
a[i] = a[i] * inv_n % mod
def convolution_ntt(a, b, mod=998244353, root=3):
"""
Perform convolution of lists a and b under modulus `mod` using NTT.
`root` is a primitive root of `mod`, where mod = k*2^m + 1.
Returns a list of size len(a)+len(b)-1, results mod `mod`.
"""
n = 1
while n < len(a) + len(b) - 1:
n <<= 1
fa = a + [0] * (n - len(a))
fb = b + [0] * (n - len(b))
_ntt(fa, invert=False, mod=mod, root=root)
_ntt(fb, invert=False, mod=mod, root=root)
for i in range(n):
fa[i] = fa[i] * fb[i] % mod
_ntt(fa, invert=True, mod=mod, root=root)
return fa[:len(a) + len(b) - 1]
# print(convolution_ntt([1, 2, 3], [4, 5, 6], mod=998244353, root=3))
# pragma GCC target("avx2")
# pragma GCC optimize("O3")
# pragma GCC optimize("unroll-loops")
from fractions import Fraction as fraction
####################################################
def treedp(path,s):
n=len(path)
used=[0 for i in range(len(path))]
used[s]=1
size=[0 for i in range(len(path))]
def _tree_dp(v,d):
res=d
for j in path[v]:
if used[j]:continue
used[j]=1
res=max(_tree_dp(j,d+1),res)
size[v]=res-d
return res
_tree_dp(s,1)
return size
n=II()
path=[[] for _ in range(n)]
for _ in range(n-1):
u,v=IIS()
u-=1
v-=1
path[u].append(v)
path[v].append(u)
dd=treedp(path,0)
ans=0
def treedp2(path,s):
n=len(path)
used=[0 for i in range(len(path))]
size=[1 for i in range(len(path))]
used[s]=1
def _tree_dp(v,p):
cnt=0
mx=[dd[j] for j in path[v] if not used[j]]
mx.sort()
mn=list([p]+[dd[j]+1 for j in path[v] if not used[j]])
mn.sort(reverse=True)
global ans
for i in range(len(mn)):
ans=max(ans,mn[i]*(i+1)+1)
for j in path[v]:
if used[j]:continue
# used[j]=1
cnt+=mx[-1]==dd[j]
for j in path[v]:
if used[j]:continue
used[j]=1
if cnt>=2 or mx[-1]!=dd[j]:
_tree_dp(j,max(mx[-1],p)+1)
elif len(mx)!=1:
_tree_dp(j,max(mx[-2],p)+1)
else:
_tree_dp(j,p+1)
return size[v]
_tree_dp(s,0)
return size
treedp2(path,0)
print(ans)
もの