結果
問題 |
No.3206 うしたウニ木あくん笑
|
ユーザー |
![]() |
提出日時 | 2025-07-18 22:50:31 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 1,928 ms / 3,000 ms |
コード長 | 6,491 bytes |
コンパイル時間 | 414 ms |
コンパイル使用メモリ | 82,360 KB |
実行使用メモリ | 451,316 KB |
最終ジャッジ日時 | 2025-08-08 12:20:39 |
合計ジャッジ時間 | 23,009 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 30 |
ソースコード
############################################################################################
# Competitive-programming template: general-purpose helpers first, the solution
# for this problem (tree DP over "urchin"-shaped subtrees) at the bottom.
# Most helpers are not used by the solution.
############################################################################################
import bisect, collections, copy, heapq, itertools, math, string, sys, queue, time, random
from decimal import Decimal
from fractions import Fraction as fraction


def factorization(n):
    """Return the prime factorization of ``n`` as ``[[prime, exponent], ...]``.

    Primes appear in increasing order.  For ``n == 1`` (no prime factors) the
    result is ``[[1, 1]]``.
    """
    arr = []
    temp = n
    # math.isqrt is exact for arbitrarily large ints, unlike a float sqrt.
    # Trial division up to floor(sqrt(n)) suffices: any leftover > 1 is prime.
    for i in range(2, math.isqrt(n) + 1):
        if temp % i == 0:
            cnt = 0
            while temp % i == 0:
                cnt += 1
                temp //= i
            arr.append([i, cnt])
    if temp != 1:
        arr.append([temp, 1])
    if not arr:
        arr.append([n, 1])
    return arr


def cycle_detection(f, x0):
    """Find the cycle reached by iterating ``f`` from ``x0`` (Brent's algorithm).

    Returns ``(mu, lam)``: ``mu`` is the index of the first element of the
    sequence x0, f(x0), f(f(x0)), ... that lies on the cycle, and ``lam`` is
    the cycle length.  ``f`` must be deterministic and the sequence must
    eventually repeat, otherwise this never terminates.
    """
    power = lam = 1
    tortoise = x0
    hare = f(x0)
    # Phase 1: find lam; the tortoise teleports to the hare at powers of two.
    while tortoise != hare:
        if power == lam:
            tortoise = hare
            power *= 2
            lam = 0
        hare = f(hare)
        lam += 1
    # Phase 2: give the hare a head start of lam steps, then advance both
    # until they meet at the first cycle element; the step count is mu.
    tortoise = hare = x0
    for _ in range(lam):
        hare = f(hare)
    mu = 0
    while tortoise != hare:
        tortoise = f(tortoise)
        hare = f(hare)
        mu += 1
    return mu, lam


def cycle_detection_lists(f, x0):
    """Split the sequence x0, f(x0), f(f(x0)), ... into its tail and its cycle.

    Returns ``(before_cycle, cycle)``: ``before_cycle`` holds the first ``mu``
    elements (those not on the cycle) and ``cycle`` holds the ``lam`` cycle
    elements, starting with the first one the sequence reaches.
    """
    mu, lam = cycle_detection(f, x0)
    before_cycle = [0] * mu
    if mu > 0:
        before_cycle[0] = x0
        for i in range(mu - 1):
            before_cycle[i + 1] = f(before_cycle[i])
    cycle = [0] * lam
    # The cycle starts at x_mu, one step AFTER the last tail element.
    cycle[0] = f(before_cycle[-1]) if mu > 0 else x0
    for i in range(lam - 1):
        cycle[i + 1] = f(cycle[i])
    return before_cycle, cycle


# --- input helpers -----------------------------------------------------------
def I(): return input()                               # one raw line
def IS(): return input().split()                      # whitespace-split tokens
def II(): return int(input())                         # one int
def IIS(): return list(map(int, input().split()))     # list of ints
def LIIS(): return list(map(int, input().split()))    # list of ints (alias)


def comb(n, r):
    """Exact binomial coefficient C(n, r) via factorials (raises if r > n)."""
    return math.factorial(n) // (math.factorial(n - r) * math.factorial(r))


def make_divisors(n):
    """Return all positive divisors of ``n`` in increasing order."""
    lower_divisors, upper_divisors = [], []
    i = 1
    while i * i <= n:
        if n % i == 0:
            lower_divisors.append(i)
            if i != n // i:
                upper_divisors.append(n // i)
        i += 1
    return lower_divisors + upper_divisors[::-1]


INF = float("inf")
MOD = 998244353
MOD2 = 10**9 + 7
sys.setrecursionlimit(3 * 10**5)
alpha = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def bit_count(x):
    """Number of set bits in the binary representation of ``x``."""
    return bin(x).count("1")


def yesno(f):
    """Print "Yes" if ``f`` is truthy, otherwise "No"."""
    if f:
        print("Yes")
    else:
        print("No")


def prime_factorization(n):
    """Like ``factorization`` but returns ``[(prime, exponent), ...]`` tuples."""
    factors = []
    temp = n
    for i in range(2, math.isqrt(n) + 1):
        if temp % i == 0:
            count = 0
            while temp % i == 0:
                count += 1
                temp //= i
            factors.append((i, count))
    if temp != 1:
        factors.append((temp, 1))
    if not factors:
        factors.append((n, 1))
    return factors


def to_base(n, base):
    """Digits of non-negative ``n`` in ``base``, most significant first."""
    if n == 0:
        return [0]
    digits = []
    while n:
        digits.append(n % base)
        n //= base
    return digits[::-1]


class ncr_object:
    """Binomial coefficients modulo ``mod`` for arguments up to ``n``.

    Factorials are precomputed; ``comb`` inverts them with Fermat's little
    theorem, so ``mod`` is assumed to be prime.
    """

    def __init__(self, n, mod=MOD):
        self.n = n
        self.mod = mod
        self.fact = [1] * (n + 1)  # fact[i] = i! % mod
        for i in range(2, n + 1):
            self.fact[i] = self.fact[i - 1] * i % mod

    def comb(self, n, r):
        """C(n, r) % mod; 0 when r > n or either argument is negative."""
        if n < r or n < 0 or r < 0:
            return 0
        return (self.fact[n]
                * pow(self.fact[r], self.mod - 2, self.mod)
                * pow(self.fact[n - r], self.mod - 2, self.mod)) % self.mod


def _ntt(a, invert, mod, root):
    """In-place iterative number-theoretic transform of ``a``.

    ``len(a)`` must be a power of two dividing ``mod - 1``; ``root`` must be a
    primitive root of ``mod``.  With ``invert=True`` the inverse transform is
    applied (including the 1/n scaling).
    """
    n = len(a)
    # Bit-reversal permutation.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j |= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    # Butterfly passes of increasing length.
    length = 2
    while length <= n:
        # wlen = root^((mod-1)/length) is a primitive length-th root of unity.
        wlen = pow(root, (mod - 1) // length, mod)
        if invert:
            wlen = pow(wlen, mod - 2, mod)
        half = length // 2
        for i in range(0, n, length):
            w = 1
            for k in range(i, i + half):
                u = a[k]
                v = a[k + half] * w % mod
                a[k] = (u + v) % mod
                a[k + half] = (u - v + mod) % mod
                w = w * wlen % mod
        length <<= 1
    if invert:
        inv_n = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * inv_n % mod


def convolution_ntt(a, b, mod=998244353, root=3):
    """
    Perform convolution of lists a and b under modulus `mod` using NTT.
    `root` is a primitive root of `mod`, where mod = k*2^m + 1.
    Returns a list of size len(a)+len(b)-1, results mod `mod`
    (an empty list if either input is empty).
    """
    if not a or not b:
        # The convolution with an empty sequence is empty.
        return []
    size = len(a) + len(b) - 1
    n = 1
    while n < size:
        n <<= 1
    fa = a + [0] * (n - len(a))
    fb = b + [0] * (n - len(b))
    _ntt(fa, invert=False, mod=mod, root=root)
    _ntt(fb, invert=False, mod=mod, root=root)
    for i in range(n):
        fa[i] = fa[i] * fb[i] % mod
    _ntt(fa, invert=True, mod=mod, root=root)
    return fa[:size]


####################################################
# Solution
####################################################
def treedp(path, s):
    """Return ``size`` where ``size[v]`` is the height (in edges) of the
    subtree of ``v`` when the tree ``path`` (adjacency lists) is rooted at
    ``s``.  Leaves get 0; vertices unreachable from ``s`` stay 0.

    Iterative DFS, so deep (path-like) trees do not hit the recursion limit.
    """
    n = len(path)
    size = [0] * n
    parent = [-1] * n
    visited = [False] * n
    visited[s] = True
    order = []  # preorder: every vertex appears after its parent
    stack = [s]
    while stack:
        v = stack.pop()
        order.append(v)
        for j in path[v]:
            if not visited[j]:
                visited[j] = True
                parent[j] = v
                stack.append(j)
    # Children before parents: propagate heights upward.
    for v in reversed(order):
        p = parent[v]
        if p != -1 and size[v] + 1 > size[p]:
            size[p] = size[v] + 1
    return size


def treedp2(path, s):
    """Rerooting pass over the tree ``path`` rooted at ``s``; updates the
    module-global ``ans``.

    For every vertex v, the branch lengths (in edges) leaving v are: one per
    child j, ``dd[j] + 1``, plus the upward branch through the parent.  Sorted
    descending as L1 >= L2 >= ..., taking the top k branches cut to length
    L_k gives k * L_k + 1 vertices; ``ans`` becomes the maximum of these over
    all v and k.

    Reads the module-global ``dd``, which must be ``treedp(path, s)`` for the
    same root ``s``.  The returned list (all ones) carries no information and
    is kept only for backward compatibility.
    """
    global ans
    n = len(path)
    size = [1] * n
    visited = [False] * n
    visited[s] = True
    # up[v]: longest path (in edges) starting at v whose first step goes to
    # v's parent; 0 for the root.
    up = [0] * n
    best = ans
    stack = [s]
    while stack:
        v = stack.pop()
        p = up[v]
        children = [j for j in path[v] if not visited[j]]
        arms = [p] + [dd[j] + 1 for j in children]
        arms.sort(reverse=True)
        for k, length in enumerate(arms, 1):
            if length * k + 1 > best:
                best = length * k + 1
        if not children:
            continue
        heights = sorted(dd[j] for j in children)
        top = heights[-1]
        ties = heights.count(top)
        for j in children:
            visited[j] = True
            # Longest branch from v that avoids j: the tallest other child,
            # or the upward branch.  If j is the unique tallest child, fall
            # back to the second tallest.
            if ties >= 2 or dd[j] != top:
                up[j] = max(top + 1, p) + 1
            elif len(heights) >= 2:
                up[j] = max(heights[-2] + 1, p) + 1
            else:
                up[j] = p + 1
            stack.append(j)
    ans = best
    return size


if __name__ == "__main__":
    # Module-level statements, so dd and ans stay globals for treedp2.
    n = II()
    path = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = IIS()
        u -= 1
        v -= 1
        path[u].append(v)
        path[v].append(u)
    dd = treedp(path, 0)
    ans = 0
    treedp2(path, 0)
    print(ans)