結果
問題 | No.408 五輪ピック |
ユーザー | lif4635 |
提出日時 | 2024-09-09 20:25:44 |
言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
結果 | AC |
実行時間 | 377 ms / 5,000 ms |
コード長 | 57,968 bytes |
コンパイル時間 | 547 ms |
コンパイル使用メモリ | 18,304 KB |
実行使用メモリ | 36,356 KB |
最終ジャッジ日時 | 2024-09-09 20:25:53 |
合計ジャッジ時間 | 8,369 ms |
ジャッジサーバーID (参考情報) | judge3 / judge2 |
(要ログイン)
テストケース
テストケース表示
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 67 ms | 16,896 KB |
testcase_01 | AC | 67 ms | 16,896 KB |
testcase_02 | AC | 68 ms | 17,152 KB |
testcase_03 | AC | 69 ms | 16,896 KB |
testcase_04 | AC | 67 ms | 16,896 KB |
testcase_05 | AC | 299 ms | 31,704 KB |
testcase_06 | AC | 200 ms | 28,376 KB |
testcase_07 | AC | 260 ms | 31,780 KB |
testcase_08 | AC | 232 ms | 29,084 KB |
testcase_09 | AC | 230 ms | 27,552 KB |
testcase_10 | AC | 198 ms | 27,380 KB |
testcase_11 | AC | 191 ms | 26,880 KB |
testcase_12 | AC | 294 ms | 33,156 KB |
testcase_13 | AC | 306 ms | 31,864 KB |
testcase_14 | AC | 110 ms | 20,352 KB |
testcase_15 | AC | 291 ms | 27,676 KB |
testcase_16 | AC | 295 ms | 29,012 KB |
testcase_17 | AC | 312 ms | 35,344 KB |
testcase_18 | AC | 319 ms | 32,028 KB |
testcase_19 | AC | 316 ms | 34,496 KB |
testcase_20 | AC | 132 ms | 22,144 KB |
testcase_21 | AC | 106 ms | 19,968 KB |
testcase_22 | AC | 183 ms | 25,600 KB |
testcase_23 | AC | 316 ms | 33,144 KB |
testcase_24 | AC | 225 ms | 27,636 KB |
testcase_25 | AC | 292 ms | 28,992 KB |
testcase_26 | AC | 377 ms | 36,356 KB |
testcase_27 | AC | 325 ms | 28,564 KB |
testcase_28 | AC | 206 ms | 26,232 KB |
testcase_29 | AC | 208 ms | 26,348 KB |
testcase_30 | AC | 68 ms | 16,896 KB |
testcase_31 | AC | 67 ms | 17,024 KB |
ソースコード
def main(): n,m = MI() edge = [set() for i in range(n)] d1 = set() e = set() for i in range(m): a,b = MI_1() if a == 0 or b == 0: d1.add(a+b) else: edge[a].add(b) edge[b].add(a) e.add((a,b)) d2 = [[-1,-1] for i in range(n)] for i in range(n): for k in edge[i]: if k in d1: if d2[i][0] == -1: d2[i][0] = k else: d2[i][1] = k for a,b in e: for i in d2[a]: for j in d2[b]: if i == -1 or j == -1: continue if i != j and i != b and j != a: print("YES") exit() print("NO") pass """==================fold line==================""" """import""" # array from bisect import bisect_left,bisect_right from heapq import heapify,heappop,heappush from collections import deque,defaultdict,Counter # math import math,random,cmath from math import comb,ceil,floor,factorial,gcd,lcm,atan2,sqrt,isqrt,pi,e from itertools import product,permutations,combinations,accumulate from functools import cmp_to_key, cache # system from typing import Generic, Iterable, Iterator, List, Tuple, TypeVar, Optional T = TypeVar('T') import sys sys.setrecursionlimit(10**9) """input""" #int-input def II() -> int : return int(input()) def MI() -> int : return map(int, input().split()) def TI() -> tuple[int] : return tuple(MI()) def LI() -> list[int] : return list(MI()) #str-input def SI() -> str : return input() def MSI() -> str : return input().split() def SI_L() -> list[str] : return list(SI()) def SI_LI() -> list[int] : return list(map(int, SI())) #multiple-input def LLI(n) -> list[list[int]]: return [LI() for _ in range(n)] def LSI(n) -> list[str]: return [SI() for _ in range(n)] #1-indexを0-indexでinput def MI_1() -> int : return map(lambda x:int(x)-1, input().split()) def TI_1() -> tuple[int] : return tuple(MI_1()) def LI_1() -> list[int] : return list(MI_1()) def ordalp(s:str) -> int|list[int]: if len(s) == 1: return ord(s)-ord("A") if s.isupper() else ord(s)-ord("a") return list(map(lambda i: ord(i)-ord("A") if i.isupper() else ord(i)-ord("a"), s)) def ordallalp(s:str) -> int|list[int]: if len(s) == 1: return 
ord(s)-ord("A")+26 if s.isupper() else ord(s)-ord("a") return list(map(lambda i: ord(i)-ord("A")+26 if i.isupper() else ord(i)-ord("a"), s)) def graph(n:str, m:str, dir:bool=False , index=-1) -> list[set[int]]: """ (頂点,辺,有向か,indexの調整) defaultは無向辺、(index)-1 """ edge = [set() for i in range(n+1+index)] for _ in range(m): a,b = map(int, input().split()) a,b = a+index,b+index edge[a].add(b) if not dir: edge[b].add(a) return edge def graph_w(n:str, m:str, dir:bool=False , index=-1) -> list[set[tuple]]: """ (頂点,辺,有向か,indexの調整) defaultは無向辺、index-1 """ edge = [set() for i in range(n+1+index)] for _ in range(m): a,b,c = map(int, input().split()) a,b = a+index,b+index edge[a].add((b,c)) if not dir: edge[b].add((a,c)) return edge """const""" mod, inf = 998244353, 1<<60 isnum = {int,float,complex} true, false, none = True, False, None def yes() -> None: print("Yes") def no() -> None: print("No") def yn(flag:bool) -> None: print("Yes" if flag else "No") def ta(flag:bool) -> None: print("Takahashi" if flag else "Aoki") alplow = "abcdefghijklmnopqrstuvwxyz" alpup = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" alpall = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" URDL = {'U':(-1,0), 'R':(0,1), 'D':(1,0), 'L':(0,-1)} DIR_4 = [[-1,0],[0,1],[1,0],[0,-1]] DIR_8 = [[-1,0],[-1,1],[0,1],[1,1],[1,0],[1,-1],[0,-1],[-1,-1]] DIR_BISHOP = [[-1,1],[1,1],[1,-1],[-1,-1]] prime60 = [2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59] # alias DD = defaultdict BSL = bisect_left BSR = bisect_right """math fanctions""" """point""" cross_pro = lambda p1,p2 : p2[0]*p1[1] - p2[1]*p1[0] #外積 dist = lambda p1,p2 : sqrt(pow(p1[0]-p2[0],2) + pow(p1[1]-p2[1],2)) def max_min_cross(p1, p2, p3, p4, touch = False): min_ab, max_ab = min(p1, p2), max(p1, p2) min_cd, max_cd = min(p3, p4), max(p3, p4) if touch: if min_ab > max_cd or max_ab < min_cd: return False return True else: if min_ab >= max_cd or max_ab <= min_cd: return False return True def cross_judge(a, b, c, d, touch = False): """線分abとcdの交差判定 接するも含むかどうか""" # 
x座標による判定 if not max_min_cross(a[0], b[0], c[0], d[0], touch): return False # y座標による判定 if not max_min_cross(a[1], b[1], c[1], d[1], touch): return False tc1 = (a[0] - b[0]) * (c[1] - a[1]) + (a[1] - b[1]) * (a[0] - c[0]) tc2 = (a[0] - b[0]) * (d[1] - a[1]) + (a[1] - b[1]) * (a[0] - d[0]) td1 = (c[0] - d[0]) * (a[1] - c[1]) + (c[1] - d[1]) * (c[0] - a[0]) td2 = (c[0] - d[0]) * (b[1] - c[1]) + (c[1] - d[1]) * (c[0] - b[0]) if touch: return tc1 * tc2 <= 0 and td1 * td2 <= 0 else: return tc1 * tc2 < 0 and td1 * td2 < 0 """primary function""" def prod(lst:list[int|str], mod = None) -> int|str: """product 文字列の場合連結""" ans = 1 if type(lst[0]) in isnum: for i in lst: ans *= i if mod: ans %= mod return ans else: return "".join(lst) def sigma(first:int, diff:int, term:int) -> int: #等差数列の和 return term*(first*2+(term-1)*diff)//2 def xgcd(a:int, b:int) -> tuple[int,int,int]: #Euclid互除法 """ans = a*x0 + b*y0""" x0, y0, x1, y1 = 1, 0, 0, 1 while b != 0: q, a, b = a // b, b, a % b x0, x1 = x1, x0 - q * x1 y0, y1 = y1, y0 - q * y1 return a, x0, y0 def modinv(a:int, mod = mod) -> int: #逆元 """逆元""" g, x, y = xgcd(a, mod) #g != 1は逆元が存在しない return -1 if g != 1 else x % m def nth_root(x:int, n:int, is_x_within_64bit = True) -> int: #n乗根 """floor(n√x)""" ngs = [-1, -1, 4294967296, 2642246, 65536, 7132, 1626, 566, 256, 139, 85, 57, 41, 31, 24, 20, 16, 14, 12, 11, 10, 9, 8, 7, 7, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] if x <= 1 or n == 1: return x if is_x_within_64bit: if n >= 64: return 1 ng = ngs[n] else: ng = x ok = 0 while abs(ok - ng) > 1: mid = (ok + ng)//2 if mid**n <= x: ok = mid else: ng = mid return ok def pi_base(p:list) -> Iterator: l = len(p) num = [0]*l while True: yield num num[~0] += 1 for i in range(l): if num[~i] == p[~i]: if i == l-1: return num[~i] = 0 num[~(i+1)] += 1 else: break """prime""" def primefact(n:int) -> dict[int,int]: #素因数分解 """素因数分解""" i = 2 pdict = dict() while i*i <= n: if n%i == 
0: cnt = 0 while n%i == 0: n //= i cnt += 1 pdict[i] = cnt i += 1 if n != 1: pdict[n] = 1 return pdict def primenumber(lim:int, get = None) -> list[int]: #素数列挙 """ 素数列挙 sieve(n)もあります get == None : リスト get >= 1 : flag get < 1 : 累積和 """ lim += 1 #素数にはflagを立てる p = [1]*lim #それ以下の素数の数を保管 cntp = [0]*lim #素数列を格納 plist = [] p[0],p[1] = 0,0 for i in range(2,lim): if p[i]: plist.append(i) for j in range(2*i,lim,i): p[j] = 0 for i in range(1,lim): cntp[i] = cntp[i-1] + p[i] if get is None: return plist elif get >= 1: return p else: return cntp def divisors(n:int) -> list[int] : #約数列挙 """約数列挙""" divs_small, divs_big = [], [] i = 1 while i * i <= n: if n % i == 0: divs_small.append(i) divs_big.append(n // i) i += 1 if divs_big[-1] == divs_small[-1]: divs_big.pop() for e in reversed(divs_big): divs_small.append(e) return divs_small """binary number""" lenbit = lambda bit: (bit).bit_length() def popcnt(n:int) -> int: #popcnt """int.bit_count() があります 64bitまで""" c=(n&0x5555555555555555)+((n>>1)&0x5555555555555555) c=(c&0x3333333333333333)+((c>>2)&0x3333333333333333) c=(c&0x0f0f0f0f0f0f0f0f)+((c>>4)&0x0f0f0f0f0f0f0f0f) c=(c&0x00ff00ff00ff00ff)+((c>>8)&0x00ff00ff00ff00ff) c=(c&0x0000ffff0000ffff)+((c>>16)&0x0000ffff0000ffff) c=(c&0x00000000ffffffff)+((c>>32)&0x00000000ffffffff) return c def binchange(n:int,fill0 = None) -> str: """10進数(int)→2進数(str) fill0:0埋め桁数""" return format(n, "0"+str(fill0)+"b") if fill0 else format(n,"b") """list""" def prefix_op(lst:list, op = lambda x,y:x+y, e = 0) -> list: #累積和 """defaultは累積和""" n = len(lst) res = [e]*(n+1) for i in range(n): res[i+1] = op(res[i], lst[i]) return res def suffix_op(lst:list, op = lambda x,y:x+y, e = 0) -> list: #累積和 """defaultは累積和""" n = len(lst) res = [e]*(n+1) for i in reversed(range(n)): res[i] = op(res[i+1], lst[i]) return res def mex(lst:list) -> int: """補集合の最小非負整数""" l = set(lst) ans = 0 while ans in l: ans += 1 return ans def inversion_cnt(lst:list, flag = None) -> int: #転倒数 """転倒数 not順列→flag立てる""" n = len(lst) if not 
flag is None: comp = Compress(lst) lst = comp.comp else: lst = list(map(lambda x : x-1, lst)) ft = FenwickTree(n) ans = [0]*n #i要素目の転倒への寄与 for i in range(n): ans[i] = ft.sum(lst[i]+1,n) ft.add(lst[i], 1) return ans def doubling(nex:list, k:int = None ,a:list = None) -> list: """nex:操作列 k:回数 a:初期列""" n = len(nex) if k is None: log = 60 else: log = (k+1).bit_length() res = [[-1]*n for _ in range(log)] #ダブリング配列 res[0] = nex[:] for cnt in range(1,log): for i in range(n): tmp = res[cnt-1][i] res[cnt][i] = res[cnt-1][tmp] if k is None: return res ans = (nex[:] if a is None else a[:]) for cnt in range(log): if k & (1<<cnt) != 0: ans = [ans[res[cnt][i]] for i in range(n)] return ans def swapcnt(a:list, b:list) -> int: """ 順列(同じ要素がない)が前提 最小操作回数を返す """ if sorted(a) != sorted(b): return -1 t = dict() cnt = 0 for i in range(n): x,y = a[i],b[i] if x == y: continue if x in t: while x in t: x_ = t[x] del t[x] x = x_ cnt += 1 if x == y: break else: t[y] = x else: t[y] = x return cnt """matrix""" def mul_matrix(A, B, mod = mod): #行列の積 A*B N = len(A) K = len(A[0]) M = len(B[0]) res = [[0 for _ in range(M)] for _ in range(N)] for i in range(N) : for j in range(K) : for k in range(M) : res[i][k] += A[i][j] * B[j][k] res[i][k] %= mod return res def pow_matrix(mat, exp, mod = mod): #二分累乗 N = len(mat) res = [[1 if i == j else 0 for i in range(N)] for j in range(N)] while exp > 0 : if exp%2 == 1 : res = mul_matrix(res, mat, mod) mat = mul_matrix(mat, mat, mod) exp //= 2 return res """enumerate""" def fact_enu(lim): #階乗列挙 #階乗 fac = [1] #階乗の逆数 divfac = [1] factorial = 1 for i in range(1,lim+1): factorial *= i factorial %= mod fac.append(factorial) divfac.append(pow(factorial,-1,mod)) return fac,divfac class Comb_enu: #combination列挙 def __init__(self,lim,mod = mod): """ mod : prime指定 lim以下のmodでcomdination計算 """ self.fac = [1,1] self.inv = [1,1] self.finv = [1,1] self.mod = mod for i in range(2,lim+1): self.fac.append(self.fac[i-1]*i%mod) self.inv.append(-self.inv[mod%i]*(mod//i)%mod) 
self.finv.append(self.finv[i-1]*self.inv[i]%mod) def comb(self,a,b): if a < b: return 0 if a < 0: return 0 return self.fac[a]*self.finv[b]*self.finv[a-b]%self.mod """str""" def int_0(str,l,r = None, over_ok = False): #str→int """ strの[l,r)桁をintで返す(0-index) 取れない場合はNone over_okを立てればrが桁を超えても返す """ lstr = len(str) if l > len(str): return None l = lstr - l if r == None: if "" == str[r:l]: return 0 return int(str[:l]) if r > len(str): if over_ok: return int(str[:l]) else: return None r = lstr - r if "" == str[r:l]: return 0 return int(str[r:l]) def lis(l): #後でちゃんと書き直してね # STEP1: LIS長パート with 使用位置 n = len(l) lisDP = [inf] * n # いまi文字目に使っている文字 indexList = [None] * n # lの[i]文字目が使われた場所を記録する for i in range(n): # 通常のLISを求め、indexListに使った場所を記録する ind = bisect_left(lisDP, l[i]) lisDP[ind] = l[i] indexList[i] = ind # STEP2: LIS復元パート by 元配列の使用した位置 # 後ろから見ていくので、まずは、LIS長目(targetIndex)のindexListを探したいとする targetIndex = max(indexList) ans = [0] * (targetIndex + 1) # 復元結果(indexListは0-indexedなのでlen=4ならmax=3で格納されているので+1する) # 後ろから見ていく for i in range(n - 1, -1, -1): # もし、一番最後に出てきているtargetIndexなら if indexList[i] == targetIndex: ans[targetIndex] = l[i] # ansのtargetIndexを確定 targetIndex -= 1 return ans """table operation""" def acc_sum(lst:list, dim = 2) -> list: if dim == 2: h,w = len(lst),len(lst[0]) res = [[0]*(w+1)] for i in range(h): res.append([0]) for j in range(w): res[-1].append(res[i+1][j] + lst[i][j]) for j in range(w): for i in range(h): res[i+1][j+1] += res[i][j+1] return res elif dim == 3: d1,d2,d3 = len(lst),len(lst[0]),len(lst[0][0]) res = [[[0]*(d3+1) for i in range(d2+1)]] for i in range(d1): res.append([[0]*(d3+1)]) for j in range(d2): res[-1].append([0]) for k in range(d3): res[-1][-1].append(res[i+1][j+1][k] + lst[i][j][k]) for j in range(d2): for k in range(d3): for i in range(d1): res[i+1][j+1][k+1] += res[i][j+1][k+1] for k in range(d3): for i in range(d1): for j in range(d2): res[i+1][j+1][k+1] += res[i+1][j][k+1] return res def copy_table(table): H,W = len(table), 
len(table[0]) res = [] for i in range(H): res.append([]) for j in range(W): res[-1].append(table[i][j]) return res def rotate_table(table): #反時計回りに回転 return list(map(list, zip(*table)))[::-1] def transpose_table(l): #行と列を入れ替え return [list(x) for x in zip(*l)] def bitconvert_table(table, letter1="#", rev=False): #各行bitに変換 H,W = len(table), len(table[0]) res = [] for h in range(H): rowBit = 0 for w in range(W): if rev: if table[h][w] == letter1: rowBit += 1<<w else: if table[h][W-w-1] == letter1: rowBit += 1<<w res.append(rowBit) return res """sort""" def argment_sort(points): #偏角ソート yposi,ynega = [],[] for x,y in points: if y > 0 or (y == 0 and x >= 0): yposi.append([x,y]) else: ynega.append([x,y]) yposi.sort(key = cmp_to_key(cross_pro)) ynega.sort(key = cmp_to_key(cross_pro)) return yposi+ynega def quick_sort(lst, comparision, left = 0, right = -1): """ (list,比較関数,(l),(r)) input : (p,q) output : True (p<q) """ i = left if right == -1: right %= len(lst) j = right pivot = (i+j)//2 dpivot = lst[pivot] while True: #条件式 while comparision(lst[i],dpivot): i += 1 while comparision(dpivot,lst[j]): j -= 1 if i >= j: break lst[i],lst[j] = lst[j],lst[i] i += 1 j -= 1 if left < i - 1: quick_sort(lst, left, i - 1) if right > j + 1: quick_sort(lst, j + 1, right) def bubble_sort(lst): """返り値:転倒数""" cnt = 0 n = len(lst) for i in range(n): for j in reversed(range(i+1),n): if a[j] > a[j-1]: a[j],a[j-1] = a[j-1],a[j] cnt += 1 return cnt def topological_sort(egde, inedge=None): n = len(edge) if inedge == None: inedge = [0]*n for v in range(n): for adj in edge(v): inedge[adj] += 1 ans = [i for i in range(n) if inedge[i] == 0] que = deque(ans) while que: q = que.popleft() for e in edge[q]: inedge[e] -= 1 if inedge[e] == 0: que.append(e) ans.append(e) return ans """graph fanctions""" def dijkstra(edge, start=0, goal=None): """計算量 O((node+edge)log(edge))""" n = len(edge) dis = [inf]*n dis[start] = 0 que = [(0, start)] heapify(que) while que: cur_dis,cur_node = heappop(que) if dis[cur_node] 
< cur_dis: continue for next_node, weight in edge[cur_node]: next_dis = cur_dis + weight if next_dis < dis[next_node]: dis[next_node] = next_dis heappush(que, (next_dis, next_node)) if goal != None: return dis[goal] return dis def warshallfloyd(dis): n = len(dis) for i in range(n): dis[i][i] = 0 for k in range(n): for i in range(n): for j in range(n): dis[i][j] = min(dis[i][j], dis[i][k]+dis[k][j]) return dis def bellmanford(edge, start=0, goal=None): """ 始点と終点が決まっている 始点から到達可能かつ、終点に到達可能な閉路のみ検出 """ n = len(edge) dis = [inf]*n pre = [-1]*n #最短経路における直前にいた頂点 negative = [False]*n #たどり着くときに負の閉路があるかどうか dis[start] = 0 for t in range(2*n): for u in range(n): for v, cost in edge[u]: if dis[v] > dis[u] + cost: if t >= n-1 and v == goal: return None #0と衝突しないように elif i >= n-1: dis[v] = -inf else: dis[v] = dis[u] + cost pre[v] = u return dis[goal] #通常はここで終わり #最短経路の復元 x = goal path = [x] while x != start: x = pre[x] path.append(x) #最短経路を含む負の閉路があるかどうか for i in reversed(range(len(path)-1)): u, v = path[i+1], path[i] if dis[v] > dis[u] + cost: dis[v] = dis[u] + cost negative[v] = True if negative[u]: negative[v] = True if negative[end]: return -1 else: return d[end] #ループ検出書くの嫌いなので用意しましょう def loop(g): pass """data stucture""" #双方向リスト # https://github.com/tatyam-prime/SortedSet?tab=readme-ov-file class SortedSet(Generic[T]): BUCKET_RATIO = 16 SPLIT_RATIO = 24 def __init__(self, a: Iterable[T] = []) -> None: "Make a new SortedSet from iterable. 
/ O(N) if sorted and unique / O(N log N)" a = list(a) n = len(a) if any(a[i] > a[i + 1] for i in range(n - 1)): a.sort() if any(a[i] >= a[i + 1] for i in range(n - 1)): a, b = [], a for x in b: if not a or a[-1] != x: a.append(x) n = self.size = len(a) num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO))) self.a = [a[n * i // num_bucket : n * (i + 1) // num_bucket] for i in range(num_bucket)] def __iter__(self) -> Iterator[T]: for i in self.a: for j in i: yield j def __reversed__(self) -> Iterator[T]: for i in reversed(self.a): for j in reversed(i): yield j def __eq__(self, other) -> bool: return list(self) == list(other) def __len__(self) -> int: return self.size def __repr__(self) -> str: return "SortedSet" + str(self.a) def __str__(self) -> str: s = str(list(self)) return "{" + s[1 : len(s) - 1] + "}" def _position(self, x: T) -> Tuple[List[T], int, int]: "return the bucket, index of the bucket and position in which x should be. self must not be empty." for i, a in enumerate(self.a): if x <= a[-1]: break return (a, i, bisect_left(a, x)) def __contains__(self, x: T) -> bool: if self.size == 0: return False a, _, i = self._position(x) return i != len(a) and a[i] == x def add(self, x: T) -> bool: "Add an element and return True if added. / O(√N)" if self.size == 0: self.a = [[x]] self.size = 1 return True a, b, i = self._position(x) if i != len(a) and a[i] == x: return False a.insert(i, x) self.size += 1 if len(a) > len(self.a) * self.SPLIT_RATIO: mid = len(a) >> 1 self.a[b:b+1] = [a[:mid], a[mid:]] return True def _pop(self, a: List[T], b: int, i: int) -> T: ans = a.pop(i) self.size -= 1 if not a: del self.a[b] return ans def discard(self, x: T) -> bool: "Remove an element and return True if removed. / O(√N)" if self.size == 0: return False a, b, i = self._position(x) if i == len(a) or a[i] != x: return False self._pop(a, b, i) return True def lt(self, x: T) -> Optional[T]: "Find the largest element < x, or None if it doesn't exist." 
for a in reversed(self.a): if a[0] < x: return a[bisect_left(a, x) - 1] def le(self, x: T) -> Optional[T]: "Find the largest element <= x, or None if it doesn't exist." for a in reversed(self.a): if a[0] <= x: return a[bisect_right(a, x) - 1] def gt(self, x: T) -> Optional[T]: "Find the smallest element > x, or None if it doesn't exist." for a in self.a: if a[-1] > x: return a[bisect_right(a, x)] def ge(self, x: T) -> Optional[T]: "Find the smallest element >= x, or None if it doesn't exist." for a in self.a: if a[-1] >= x: return a[bisect_left(a, x)] def __getitem__(self, i: int) -> T: "Return the i-th element." if i < 0: for a in reversed(self.a): i += len(a) if i >= 0: return a[i] else: for a in self.a: if i < len(a): return a[i] i -= len(a) raise IndexError def pop(self, i: int = -1) -> T: "Pop and return the i-th element." if i < 0: for b, a in enumerate(reversed(self.a)): i += len(a) if i >= 0: return self._pop(a, ~b, i) else: for b, a in enumerate(self.a): if i < len(a): return self._pop(a, b, i) i -= len(a) raise IndexError def index(self, x: T) -> int: "Count the number of elements < x." ans = 0 for a in self.a: if a[-1] >= x: return ans + bisect_left(a, x) ans += len(a) return ans def index_right(self, x: T) -> int: "Count the number of elements <= x." ans = 0 for a in self.a: if a[-1] > x: return ans + bisect_right(a, x) ans += len(a) return ans class SortedList(Generic[T]): BUCKET_RATIO = 16 SPLIT_RATIO = 24 def __init__(self, a: Iterable[T] = []) -> None: "Make a new SortedMultiset from iterable. 
/ O(N) if sorted / O(N log N)" a = list(a) n = self.size = len(a) if any(a[i] > a[i + 1] for i in range(n - 1)): a.sort() num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO))) self.a = [a[n * i // num_bucket : n * (i + 1) // num_bucket] for i in range(num_bucket)] def __iter__(self) -> Iterator[T]: for i in self.a: for j in i: yield j def __reversed__(self) -> Iterator[T]: for i in reversed(self.a): for j in reversed(i): yield j def __eq__(self, other) -> bool: return list(self) == list(other) def __len__(self) -> int: return self.size def __repr__(self) -> str: return "SortedMultiset" + str(self.a) def __str__(self) -> str: s = str(list(self)) return "{" + s[1 : len(s) - 1] + "}" def _position(self, x: T) -> Tuple[List[T], int, int]: "return the bucket, index of the bucket and position in which x should be. self must not be empty." for i, a in enumerate(self.a): if x <= a[-1]: break return (a, i, bisect_left(a, x)) def __contains__(self, x: T) -> bool: if self.size == 0: return False a, _, i = self._position(x) return i != len(a) and a[i] == x def count(self, x: T) -> int: "Count the number of x." return self.index_right(x) - self.index(x) def add(self, x: T) -> None: "Add an element. / O(√N)" if self.size == 0: self.a = [[x]] self.size = 1 return a, b, i = self._position(x) a.insert(i, x) self.size += 1 if len(a) > len(self.a) * self.SPLIT_RATIO: mid = len(a) >> 1 self.a[b:b+1] = [a[:mid], a[mid:]] def _pop(self, a: List[T], b: int, i: int) -> T: ans = a.pop(i) self.size -= 1 if not a: del self.a[b] return ans def discard(self, x: T) -> bool: "Remove an element and return True if removed. / O(√N)" if self.size == 0: return False a, b, i = self._position(x) if i == len(a) or a[i] != x: return False self._pop(a, b, i) return True def lt(self, x: T) -> Optional[T]: "Find the largest element < x, or None if it doesn't exist." 
for a in reversed(self.a): if a[0] < x: return a[bisect_left(a, x) - 1] def le(self, x: T) -> Optional[T]: "Find the largest element <= x, or None if it doesn't exist." for a in reversed(self.a): if a[0] <= x: return a[bisect_right(a, x) - 1] def gt(self, x: T) -> Optional[T]: "Find the smallest element > x, or None if it doesn't exist." for a in self.a: if a[-1] > x: return a[bisect_right(a, x)] def ge(self, x: T) -> Optional[T]: "Find the smallest element >= x, or None if it doesn't exist." for a in self.a: if a[-1] >= x: return a[bisect_left(a, x)] def __getitem__(self, i: int) -> T: "Return the i-th element." if i < 0: for a in reversed(self.a): i += len(a) if i >= 0: return a[i] else: for a in self.a: if i < len(a): return a[i] i -= len(a) raise IndexError def pop(self, i: int = -1) -> T: "Pop and return the i-th element." if i < 0: for b, a in enumerate(reversed(self.a)): i += len(a) if i >= 0: return self._pop(a, ~b, i) else: for b, a in enumerate(self.a): if i < len(a): return self._pop(a, b, i) i -= len(a) raise IndexError def index(self, x: T) -> int: "Count the number of elements < x." ans = 0 for a in self.a: if a[-1] >= x: return ans + bisect_left(a, x) ans += len(a) return ans def index_right(self, x: T) -> int: "Count the number of elements <= x." 
ans = 0 for a in self.a: if a[-1] > x: return ans + bisect_right(a, x) ans += len(a) return ans class Deque: #両端以外もO(1)でアクセスできるdeque def __init__(self, src_arr=[], max_size=300000): self.N = max(max_size, len(src_arr)) + 1 self.buf = list(src_arr) + [None] * (self.N - len(src_arr)) self.head = 0 self.tail = len(src_arr) def __index(self, i): l = len(self) if not -l <= i < l: raise IndexError('index out of range: ' + str(i)) if i < 0: i += l return (self.head + i) % self.N def __extend(self): ex = self.N - 1 self.buf[self.tail+1 : self.tail+1] = [None] * ex self.N = len(self.buf) if self.head > 0: self.head += ex def is_full(self): return len(self) >= self.N - 1 def is_empty(self): return len(self) == 0 def append(self, x): if self.is_full(): self.__extend() self.buf[self.tail] = x self.tail += 1 self.tail %= self.N def appendleft(self, x): if self.is_full(): self.__extend() self.buf[(self.head - 1) % self.N] = x self.head -= 1 self.head %= self.N def pop(self): if self.is_empty(): raise IndexError('pop() when buffer is empty') ret = self.buf[(self.tail - 1) % self.N] self.tail -= 1 self.tail %= self.N return ret def popleft(self): if self.is_empty(): raise IndexError('popleft() when buffer is empty') ret = self.buf[self.head] self.head += 1 self.head %= self.N return ret def __len__(self): return (self.tail - self.head) % self.N def __getitem__(self, key): return self.buf[self.__index(key)] def __setitem__(self, key, value): self.buf[self.__index(key)] = value def __str__(self): return 'Deque({0})'.format(str(list(self))) class WeightedUnionFind: #重み付きunion-find def __init__(self, N): self.N = N self.parents = [-1] * N self.rank = [0] * N self.weight = [0] * N def root(self, x): if self.parents[x] == -1: return x rx = self.root(self.parents[x]) self.weight[x] += self.weight[self.parents[x]] self.parents[x] = rx return self.parents[x] def get_weight(self, x): self.root(x) return self.weight[x] def unite(self, x, y, d): ''' A[x] - A[y] = d ''' w = d + 
self.get_weight(x) - self.get_weight(y) rx = self.root(x) ry = self.root(y) if rx == ry: _, d_xy = self.diff(x, y) if d_xy == d: return True else: return False if self.rank[rx] < self.rank[ry]: rx, ry = ry, rx w = -w if self.rank[rx] == self.rank[ry]: self.rank[rx] += 1 self.parents[ry] = rx self.weight[ry] = w return True def is_same(self, x, y): return self.root(x) == self.root(y) def diff(self, x, y): if self.is_same(x, y): return True, self.get_weight(y) - self.get_weight(x) else: return False, 0 """binary search""" def bi_int(comparison, ok = 0, ng = inf): """ [lowlim,ans)だとTrueで[ans,uplim)だとFalse のイメージで実装 """ if not pred(ok): #条件を満たすことがない return ok while abs(ng - ok) > 1: mid = ok + (ng - ok)//2 if comparison(mid): ok = mid else: ng = mid return ok def bi_float(comparison, ok = 0, ng = inf, error = 10**(-9)): """ [lowlim,ans)だとTrueで[ans,uplim)だとFalse のイメージで実装 """ if not pred(ok): #条件を満たすことがない return ok #相対誤差と絶対誤差のどちらかがerror以下で終了 while abs(ng - ok)/abs(ng) > error and abs(ng - ok) > eroor: mid = ok + (ng - ok)/2 if comparison(mid): ok = mid else: ng = mid return ok """compress""" class Compress: #座標圧縮(一次元) def __init__(self, arr): values = sorted(set(arr)) self.translator = dict([(values[i], i) for i in range(len(values))]) self.inv_translator = values self.comp = [] for x in arr: self.comp.append(self.translator[x]) #圧縮前→圧縮後 def to_comp(self, x): return self.translator[x] #圧縮後→圧縮前 def from_comp(self, v): return self.inv_translator[v] #lstを変換 def lst_comp(self, lst): return [self.to_comp(i) for i in lst] class Compress2D: #2次元リスト[x,y]の座標圧縮 def __init__(self, arr): self.x = Compress([x for x, y in arr]) self.y = Compress([y for x, y in arr]) self.comp = [] for x,y in arr: self.comp.append([self.x.translator[x],self.y.translator[y]]) #圧縮前→圧縮後 def to_comp(self, x): return (self.x.translator[x[0]], self.y.translator[x[1]]) #圧縮後→圧縮前 def from_comp(self, v): return (self.x.translator[v[0]], self.y.translator[v[1]]) class RollingHash: #hash化 def __init__(self, string, 
base = 37, mod = 10**9 + 9): self.mod = mod l = len(string) self.hash = [0]*(l+1) for i in range(1,l+1): self.hash[i] = ( self.hash[i-1] * base + ord(string[i-1]) ) % mod self.pw = [1]*(l+1) for i in range(1,l+1): self.pw[i] = self.pw[i-1] * base % mod def get(self, l, r): """s[l:r]のhash""" return (self.hash[r] - self.hash[l] * self.pw[r-l]) % self.mod class ZobristHash: #多重集合の一致判定 def __init__(self, n, as_list:bool = False, mod = (1<<61)-1): self.N = n self.conversion = [random.randint(1, mod - 1) for i in range(n+1)] self.as_list = as_list #setとして扱うかlistの並び替えか self.Mod = mod def makehash(self, a:list): la = len(a) hashlst = [0]*(la+1) if self.as_list: #listの並び替えとしての一致 for i in range(la): hashlst[i+1] = (hashlst[i]+self.conversion[a[i]])%self.Mod return hashlst else: #setとしての一致 cnt = {} for i in range(la): if a[i] in cnt: hashlst[i+1] = hashlst[i+1] continue cnt.add(a[i]) hashlst[i+1] = hashlst[i]^self.conversion[a[i]] return hashlst def get(self, hashedlst:list, l:int, r:int): """a[l:r]のhashを返します""" if self.as_list: return (hashedlst[r]-hashedlst[l])%self.Mod else: return hashedlst[r]^hashedlst[l] """畳み込み??""" """graph""" class GridSearch: def __init__(self, table): """盤面の受取""" self.table = table self.H = len(table) self.W = len(table[0]) self.wall = "#" self.dist = [[inf]*self.W for _ in range(self.H)] def find(self, c): """始点,終点等の取得""" for h in range(self.H): for w in range(self.W): if self.table[h][w] == c: return (h,w) return None def set_wall(self, string): """壁の設定""" self.wall = string def can_start(self, *start): """探索済みでないかつ壁でない""" if len(start) == 1: i,j = start[0][0],start[0][1] else: i,j = start[0],start[1] if self.dist[i][j] == inf and not self.table[i][j] in self.wall: return True else: return False def island(self, transition = DIR_4): """連結成分の検出""" H, W = self.H, self.W self.island_id = [[-1]*W for _ in range(H)] self.island_size = [[-1]*W for _ in range(H)] crr_id = 0 id2size = dict() for sh in range(H): for sw in range(W): if self.table[sh][sw] 
in self.wall: continue if self.island_id[sh][sw] != -1: continue deq = deque() deq.append((sh,sw)) crr_size = 1 self.island_id[sh][sw] = crr_id while deq: h,w = deq.popleft() for dh, dw in transition: nh, nw = h+dh, w+dw if (not 0<= nh < H) or (not 0 <= nw < W): continue if self.table[nh][nw] in self.wall: continue if self.island_id[nh][nw] == -1: self.island_id[nh][nw] = crr_id deq.append((nh, nw)) crr_size += 1 id2size[crr_id] = crr_size crr_id += 1 for h in range(H): for w in range(W): if self.table[h][w] in self.wall: continue self.island_size[h][w] = id2size[self.island_id[h][w]] return self.island_id, self.island_size def DFS(self, start, goal=None, transition = DIR_4): """ DFSをします input : (start,(goal),(transition)) output : dis(table) or goalまでのdis(int) """ H, W = self.H, self.W deq = deque() deq.append(start) self.dist[start[0]][start[1]] = 0 if start == goal: return 0 while deq: h,w = deq.popleft() for dh, dw in transition: nh = h+dh nw = w+dw # gridの範囲外. if (not 0 <= nh < H) or (not 0 <= nw < W): continue # wallに設定されている文字なら. if self.table[nh][nw] in self.wall: continue new_dist = self.dist[h][w] + 1 #goalが引数で与えられていてgoalに達したら終了. if goal and (nh,nw) == goal: return new_dist if self.dist[nh][nw] > new_dist: self.dist[nh][nw] = new_dist deq.append((nh,nw)) if goal: return -1 return self.dist def DFS_break(self, start, goal=None, transition = DIR_4): """ 壁をcost = 1で破壊できる それ以外の移動はcost = 0 input : (start,(goal),(transition)) output : dis(table) or goalまでのdis(int) """ H, W = self.H, self.W deq = deque() deq.append(start) self.dist[start[0]][start[1]] = 0 if start == goal: return 0 while deq: h,w = deq.popleft() for dh, dw in transition: nh = h+dh nw = w+dw # gridの範囲外. if (not 0 <= nh < H) or (not 0 <= nw < W): continue now_dist = self.dist[h][w] #goalが引数で与えられていてgoalに達したら終了. if goal and (nh,nw) == goal: return now_dist # wallに設定されている文字なら. 
if self.table[nh][nw] in self.wall: if self.dist[nh][nw] > now_dist+1: self.dist[nh][nw] = now_dist+1 deq.append((nh,nw)) elif self.dist[nh][nw] > now_dist: self.dist[nh][nw] = now_dist deq.appendleft((nh,nw)) if goal: return -1 return self.dist #バリエーションとして #方向変換したら距離加算 #壁破壊で距離加算 #壁の種類として他のものがある #視線が壁になる #マグネット #移動に制限がある(エネルギー) class RootedTree: """ __allmethod__ autobuild -> obj : inputから構築 set_root -> None is_root,is_leaf -> bool yield_edges -> Iterator ### set_weight -> None : weightのdict生成 get_weight -> int : dictから重さを取得 get_depth -> int : rootからの深さ ### build_depth -> None : 深さの構築 build_des_size -> None : centroid_decomposition : build_centroid_dist is_member_of_centroid_tree is_id_larger get_higher_centroids_with_self yield_centroid_children find_lowest_common_centroid """ @classmethod def autobuild(cls, N, root = 0, input_index = 1): """ (u,v) , (u,v,c)に対応 rootを設定したくないならNone """ G = [[] for _ in range(N)] if N == 1: obj = RootedTree(G) if root is not None: obj.set_root(0) return obj line1 = list(map(int, input().split())) assert 2 <= len(line1) <= 3 # 重み無し. 
if len(line1) == 2: u,v = line1 u,v = u-input_index, v-input_index G[u].append(v) G[v].append(u) for _ in range(N-2): u,v = map(int, input().split()) u,v = u-input_index, v-input_index G[u].append(v) G[v].append(u) obj = RootedTree(G) if root is not None: obj.set_root(0) return obj else: u,v,c = line1 u,v = u-input_index, v-input_index G[u].append(v) G[v].append(u) edge = [(u,v,c)] for _ in range(N-2): u,v,c = map(int, input().split()) u,v = u-input_index, v-input_index G[u].append(v) G[v].append(u) edge.append((u,v,c)) obj = RootedTree(G) obj.set_weight(edge) if root is not None: obj.set_root(0) return obj def __init__(self, G): self.N = len(G) self.G = G self._rooted = False self._has_weight = False self._key = 10**7 def set_root(self, root): """ DFSついでにトポロジカルソート列も求める """ assert self._rooted == False self.root = root n, G = self.N, self.G par, ch, ts = [-1]*n, [[] for _ in range(n)], [] deq = deque([root]) while deq: v = deq.popleft() ts.append(v) for adj in G[v]: if adj == par[v]: continue par[adj] = v ch[v].append(adj) deq.append(adj) self.parent, self.children, self.ts_order = par, ch, ts self._rooted = True def encode(self, u, v): #edgte -> int return u*self._key + v def decode(self, uv): #int -> edge return divmod(uv, self._key) def is_root(self, v) -> bool: return v == self.root def is_leaf(self, v) -> bool: return len(self.children[v]) == 0 def yield_edges(self) -> Iterator[tuple]: """rootに近い順にedgeを回すIterator""" N, ts, ch = self.N, self.ts_order, self.children if self._has_weight: wei, en = self.weight, self.encode for v in ts: for c in ch[v]: yield (v,c,wei[en(v,c)]) else: for v in ts: for c in ch[v]: yield (v,c) """ weight """ #edge->weightにO(1)でアクセスできるようにdictで持つ def set_weight(self, edge): assert self._has_weight == False d = {} for u,v,c in edge: d[self.encode(u,v)] = d[self.encode(v,u)] = c self.weight = d self._has_weight = True def get_weight(self, u, v) -> int: return self.weight[self.encode(u, v)] """depth : rootからの深さ""" def get_depth(self, v) -> 
int: # obj.depth[v] と同じ. if not hasattr(self, "depth"): self.build_depth() return self.depth[v] def build_depth(self): assert self._rooted N, ch, ts = self.N, self.children, self.ts_order depth = [0]*N for v in ts: for c in ch[v]: depth[c] = depth[v] + 1 self.depth = depth """subtree_size : 部分木""" def build_des_size(self): assert self._rooted if hasattr(self, "des_size"): return N, ts, par = self.N, self.ts_order, self.parent des = [1]*N for i in range(N-1,0,-1): v = ts[i] p = par[v] des[p] += des[v] self.des_size = des """centroid : 重心分解""" def centroid_decomposition(self, build_dist=True): """ centroid_id[i] : DFS的に重心分解をしたとき, 頂点iを重心とする重心木が何番目に登場するか. 頂点cenを重心とする重心木の頂点を探索する際は,頂点cenから, T.is_id_larger(v, cen)==True な頂点vのみを使って到達可能な頂点vを探索すればいい. centroid_dfs_order : centroid_id の逆順列. reveresed(centroid_dfs_order)順に重心木を探索することで より小さい重心木についての結果を用いたDPが可能. """ if hasattr(self, "centroid_id"): return # 根に依存しないアルゴリズムなので根0にしていい. if not self._rooted: self.set_root(0) if not hasattr(self, "des_size"): self.build_des_size() # sizeは書き換えるのでコピーを使用. N, G, size = self.N, self.G, self.des_size[:] c_id, c_depth, c_par, c_dfs_order = [-1]*N, [-1]*N, [-1]*N, [] stack = [(self.root, -1, 0)] # 重心を見つけたら,「重心分解後のその頂点が重心となる部分木」の # DFS順の順番, 深さ, 重心木における親にあたる部分木の重心を記録 for order in range(N): v, prev, d = stack.pop() while True: for adj in G[v]: if c_id[adj] == -1 and size[adj]*2 > size[v]: # adjを今見ている部分木の根にし,sizeを書き換える. size[v], size[adj], v = size[v]-size[adj], size[v], adj break else: break c_id[v], c_depth[v], c_par[v] = order, d, prev c_dfs_order.append(v) if size[v] > 1: for adj in G[v]: if c_id[adj] == -1: stack.append((adj, v, d+1)) self.centroid_id, self.centroid_depth, self.centroid_parent, self.centroid_dfs_order = c_id, c_depth, c_par, c_dfs_order if build_dist == True: self.build_centroid_dist() def build_centroid_dist(self): """ 重心同士を結んだ木を重心分解木と呼ぶことにする. 重心分解木のみを考えて解けるなら楽だが、 「各重心木における重心(根)との距離」 を求めるには元の辺の情報が必要.一方それさえ求めれば、 重心分解木に対する考察だけで足りる問題が多い. 
""" if hasattr(self, "centroid_dist"): return False if not hasattr(self, "centroid_id"): self.centroid_decomposition() N, G, c_depth = self.N, self.G ,self.centroid_depth is_id_larger = self.is_id_larger log = max(c_depth) + 1 # dist[d][v] : vが深さdの重心木に属しているならその重心からの距離. dist = [[-1]*N for _ in range(log)] for cen in range(N): d = c_depth[cen] stack = [cen] dist[d][cen] = 0 while stack: v = stack.pop() for adj in G[v]: if dist[d][adj] == -1 and is_id_larger(adj, cen): if self._has_weight: dist[d][adj] = dist[d][v] + self.weight[self.encode(v, adj)] else: dist[d][adj] = dist[d][v] + 1 stack.append(adj) self.centroid_log, self.centroid_dist = log, dist def is_member_of_centroid_tree(self, v, c): # 頂点vが重心cの重心木に属するかを判定 O(logN) vs = self.get_higher_centroids_with_self(v) return c in vs def is_id_larger(self, u, v): # 重心cからBFSする時に、is_id_larger(adj, c)とすれば重心木内部を探索できる. return self.centroid_id[u] > self.centroid_id[v] def get_higher_centroids_with_self(self, c): # 頂点cが属する重心木の重心をサイズの昇順に列挙. O(logN) vs = [] for d in range(self.centroid_depth[c], -1, -1): vs.append(c) c = self.centroid_parent[c] return vs def yield_centroid_children(self, v): # 頂点vを重心とする重心木における, # 「『vの子供を根とした部分木』と構成が同じ重心木の重心」を列挙する. # 「『重心木』の木」における「『vを重心とする重心木』の子の重心木」の重心 ともいえる. G, is_id_larger, c_par = self.G, self.is_id_larger, self.centroid_parent for ch in G[v]: if is_id_larger(ch, v): ch_cen = ch while c_par[ch_cen] != v: ch_cen = c_par[ch_cen] yield (ch, ch_cen) def find_lowest_common_centroid(self, u, v): # 頂点u,vをどちらも含む最小の重心木を返す. 
O(logN) c_depth, c_par = self.centroid_depth, self.centroid_parent du, dv = c_depth[u], c_depth[v] if du > dv: u,v = v,u du,dv = dv,du for _ in range(dv - du): v = c_par[v] while u != v: u,v = c_par[u],c_par[v] return u def build_the_centroid(self): """ 全体の重心だけで十分な時用 O(N) """ if not self._rooted: self.set_root(0) if hasattr(self, "the_centroid"): return False if hasattr(self, "centroid_id"): self.the_centroid = self.centroid_id[0] return True if not hasattr(self, "des_size"): self.build_des_size() N, ch, size = self.N, self.children, self.des_size v = self.root while True: for c in ch[v]: if size[c] > N // 2: v = c break else: self.the_centroid = v return True def get_the_centroid(self): if hasattr(self, "centroid_id"): return self.centroid_id[0] if not hasattr(self, "the_centroid"): self.build_the_centroid() return self.the_centroid """ tree dp """ def dp_from_leaf(self, merge, e, add_root, push=lambda obj,data,dst,src:data): """ チートシート 部分木の大きさ : dp_from_leaf(lambda x,y:x+y, 0, lambda x,y,z:y+1) """ assert self._rooted # pushで形整えたデータを親の単位元で初期化されたノードにmerge. # 子が全部mergeされたらadd_rootで自身の頂点の情報を追加. 
N, ts, par = self.N, self.ts_order, self.parent sub = [e] * N for i in range(N-1,-1,-1): v = ts[i] sub[v] = add_root(self, sub[v], v) p = par[v] if p != -1: sub[p] = merge(sub[p], push(self, sub[v], p, v)) return sub def rerooting_dp(self, merge, e, add_root, push=lambda obj,data,dst,src:data): """全方位木DP 途中で頂点を変更する""" if self._rooted == False: self.set_root(0) sub = self.dp_from_leaf(merge, e, add_root, push) N = self.N ts, par, ch = self.ts_order, self.parent, self.children compl, dp = [e]*N, [e]*N for i in range(N): v = ts[i] p, size = par[v], len(ch[v]) left, right = [e]*size, [e]*size for j in range(size): c = ch[v][j] left[j] = merge(left[j-1] if j>0 else e, push(self, sub[c], v, c)) for j in range(size-1,-1,-1): c = ch[v][j] right[j] = merge(right[j+1] if j<size-1 else e, push(self, sub[c], v, c)) for j in range(size): c = ch[v][j] compl[c] = merge(compl[c], left[j-1] if j>0 else e) compl[c] = merge(compl[c], right[j+1] if j<size-1 else e) if p != -1: compl[c] = merge(compl[c], push(self, compl[v], v, p)) compl[c] = add_root(self, compl[c], v) if p != -1: dp[v] = merge(dp[v], push(self, compl[v], v, p)) dp[v] = merge(dp[v], left[-1] if size else e) dp[v] = add_root(self, dp[v], v) return dp """ dist """ def build_dist_from_root(self, op = lambda x,y : x+y): assert self._rooted if hasattr(self, "dist_from_root"): return N, ts, ch = self.N, self.ts_order, self.children dist = [0]*N if self._has_weight: wei, en = self.weight, self.encode else: wei, en = [1], lambda a,b:0 for v in ts: for c in ch[v]: dist[c] = op(dist[v], wei[en(v, c)]) self.dist_from_root = dist def calc_dist_from_a_node(self, v, op = lambda x,y : x+y): """ v -> children[v] のdist """ N, G = self.N, self.G dist, que = [None]*N, [v] dist[v] = 0 if self._has_weight: wei, en = self.weight, self.encode else: wei, en = [1], lambda a,b:0 while que: v = que.pop() for adj in G[v]: if dist[adj] is None: dist[adj] = op(dist[v], wei[en(v, adj)]) que.append(adj) return dist def build_diameter(self): 
"""直径を求める""" self.build_dist_from_root() if hasattr(self, "diameter"): return dist_r = self.dist_from_root v = dist_r.index(max(dist_r)) dist_v = self.calc_dist_from_a_node(v) dia = max(dist_v) u = dist_v.index(dia) self.diameter, self.end_points_of_diameter = dia, [v, u] def get_diameter(self): """直径の取得""" if hasattr(self, "diameter"): return self.diameter self.build_diameter() return self.diameter main() """==================fold line 1800=================="""