結果

問題 No.2852 Yakitori Optimization Problem
ユーザー はるるんはるるん
提出日時 2024-08-25 13:38:58
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 239 ms / 2,000 ms
コード長 57,228 bytes
コンパイル時間 289 ms
コンパイル使用メモリ 91,460 KB
実行使用メモリ 142,084 KB
最終ジャッジ日時 2024-08-25 13:39:13
合計ジャッジ時間 4,810 ms
ジャッジサーバーID (参考情報) judge1 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 127 ms 95,832 KB
testcase_01 AC 191 ms 127,712 KB
testcase_02 AC 171 ms 118,960 KB
testcase_03 AC 233 ms 137,072 KB
testcase_04 AC 197 ms 126,308 KB
testcase_05 AC 125 ms 97,988 KB
testcase_06 AC 114 ms 81,068 KB
testcase_07 AC 198 ms 128,836 KB
testcase_08 AC 96 ms 79,128 KB
testcase_09 AC 229 ms 142,084 KB
testcase_10 AC 228 ms 137,252 KB
testcase_11 AC 233 ms 137,512 KB
testcase_12 AC 235 ms 137,124 KB
testcase_13 AC 232 ms 137,380 KB
testcase_14 AC 234 ms 137,124 KB
testcase_15 AC 228 ms 137,124 KB
testcase_16 AC 239 ms 136,944 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

def main():
    """Read the instance from stdin and print the optimal total."""
    n, k = MI()
    a = LI()
    b = LI()
    c = LI()
    # Baseline: all of a plus all of c.  Exactly k indices then trade c[i]
    # for b[i]; taking the k largest gains b[i] - c[i] maximises the total.
    gains = sorted((b[i] - c[i] for i in range(n)), reverse=True)
    total = sum(a) + sum(c)
    for idx in range(k):
        total += gains[idx]
    print(total)
"""==================fold line=================="""

"""import"""
# array
from bisect import bisect_left,bisect_right
from heapq import heapify,heappop,heappush
from collections import deque,defaultdict,Counter

# math
import math,random,cmath
from math import comb,ceil,floor,factorial,gcd,lcm,atan2,sqrt,isqrt,pi,e
from itertools import product,permutations,combinations,accumulate
from functools import cmp_to_key, cache

# system
from typing import Generic, Iterable, Iterator, List, Tuple, TypeVar, Optional
T = TypeVar('T')
import sys
sys.setrecursionlimit(10**9)

"""input"""
# integer input (one line per call)
def II() -> int : return int(input())
def MI() -> Iterator[int] : return map(int, input().split())
def TI() -> tuple[int, ...] : return tuple(MI())
def LI() -> list[int] : return list(MI())
# string input
def SI() -> str : return input()
def MSI() -> list[str] : return input().split()
def SI_L() -> list[str] : return list(SI())
def SI_LI() -> list[int] : return list(map(int, SI()))
# several lines at once
def LLI(n) -> list[list[int]]: return [LI() for _ in range(n)]
def LSI(n) -> list[str]: return [SI() for _ in range(n)]
# read 1-indexed integers and convert them to 0-indexed
def MI_1() -> Iterator[int] : return map(lambda x:int(x)-1, input().split())
def TI_1() -> tuple[int, ...] : return tuple(MI_1())
def LI_1() -> list[int] : return list(MI_1())

def ordalp(s:str) -> int|list[int]:
    if len(s) == 1:
        return ord(s)-ord("A") if s.isupper() else ord(s)-ord("a")
    return list(map(lambda i: ord(i)-ord("A") if i.isupper() else ord(i)-ord("a"), s))

def ordallalp(s:str) -> int|list[int]:
    if len(s) == 1:
        return ord(s)-ord("A")+26 if s.isupper() else ord(s)-ord("a")
    return list(map(lambda i: ord(i)-ord("A")+26 if i.isupper() else ord(i)-ord("a"), s))

def graph(n:int, m:int, dir:bool=False , index=-1) -> list[set[int]]:
    """Read m edges "a b" from stdin and build an adjacency-set list.

    n: number of vertices; m: number of edge lines to read.
    dir: if True each edge is directed a -> b, otherwise it is added both ways.
    index: offset added to every vertex id read; the default -1 converts
        1-indexed input to 0-indexed. The result has n + 1 + index entries.
    """
    edge = [set() for _ in range(n+1+index)]
    for _ in range(m):
        a, b = map(int, input().split())
        a, b = a+index, b+index
        edge[a].add(b)
        if not dir:
            edge[b].add(a)
    return edge

def graph_w(n:int, m:int, dir:bool=False , index=-1) -> list[set[tuple]]:
    """Read m weighted edges "a b c" from stdin into an adjacency-set list.

    Each entry edge[a] holds (b, c) pairs (neighbour, weight).
    n: number of vertices; m: number of edge lines to read.
    dir: if True each edge is directed a -> b, otherwise it is added both ways.
    index: offset added to both vertex ids; the default -1 converts 1-indexed
        input to 0-indexed. The result has n + 1 + index entries.
    """
    edge = [set() for _ in range(n+1+index)]
    for _ in range(m):
        a, b, c = map(int, input().split())
        a, b = a+index, b+index
        edge[a].add((b, c))
        if not dir:
            edge[b].add((a, c))
    return edge

"""const"""
mod, inf = 998244353, 1 << 60  # default prime modulus; large integer used as "infinity"
isnum = {int, float, complex}  # numeric types
true, false, none = True, False, None  # lowercase aliases of the built-in constants


def yes() -> None:
    """Print "Yes"."""
    print("Yes")


def no() -> None:
    """Print "No"."""
    print("No")


def yn(flag:bool) -> None:
    """Print "Yes" if flag is truthy, otherwise "No"."""
    if flag:
        print("Yes")
    else:
        print("No")


def ta(flag:bool) -> None:
    """Print "Takahashi" if flag is truthy, otherwise "Aoki"."""
    if flag:
        print("Takahashi")
    else:
        print("Aoki")


alplow = "".join(map(chr, range(ord("a"), ord("z") + 1)))  # "abc...z"
alpup = alplow.upper()  # "ABC...Z"
alpall = alplow + alpup  # lowercase letters followed by uppercase letters
# direction letter -> coordinate delta (U/D change the first component, L/R the second)
URDL = dict(U=(-1, 0), R=(0, 1), D=(1, 0), L=(0, -1))
DIR_4 = [list(URDL[key]) for key in "URDL"]  # U, R, D, L as lists
# eight neighbours, starting from U and going through R, D, L
DIR_8 = [[-1, 0], [-1, 1], [0, 1], [1, 1], [1, 0], [1, -1], [0, -1], [-1, -1]]
DIR_BISHOP = [list(step) for step in DIR_8[1::2]]  # the four diagonal entries of DIR_8
prime60 = [p for p in range(2, 60) if all(p % q for q in range(2, p))]  # primes below 60

# short aliases
DD, BSL, BSR = defaultdict, bisect_left, bisect_right

"""math fanctions"""

"""point"""
def cross_pro(p1, p2):
    """Return p2[0]*p1[1] - p2[1]*p1[0], i.e. the cross product p2 x p1.

    Negative when p2 lies counter-clockwise of p1 (e.g. p1=(1,0), p2=(0,1) -> -1).
    """
    return p2[0] * p1[1] - p2[1] * p1[0]


def dist(p1, p2):
    """Euclidean distance between two 2D points."""
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return sqrt(pow(dx, 2) + pow(dy, 2))

def max_min_cross(p1, p2, p3, p4, touch = False):
    """Do the 1-D intervals [p1, p2] and [p3, p4] overlap?

    Endpoint order inside each pair does not matter. With touch=True sharing a
    single endpoint counts as overlapping; with touch=False it does not.
    """
    lo_a, hi_a = min(p1, p2), max(p1, p2)
    lo_b, hi_b = min(p3, p4), max(p3, p4)
    if touch:
        return not (lo_a > hi_b or hi_a < lo_b)
    return not (lo_a >= hi_b or hi_a <= lo_b)

def cross_judge(a, b, c, d, touch = False):
    """Segment intersection test for segments ab and cd (2D points).

    touch=True uses non-strict comparisons, so merely touching segments count
    as intersecting; touch=False requires a strict crossing.
    """
    # quick rejection: the projections on each axis must overlap
    for axis in (0, 1):
        if not max_min_cross(a[axis], b[axis], c[axis], d[axis], touch):
            return False

    def side(p, q, r):
        # sign tells which side of the line through p and q the point r lies on
        return (p[0] - q[0]) * (r[1] - p[1]) + (p[1] - q[1]) * (p[0] - r[0])

    tc1, tc2 = side(a, b, c), side(a, b, d)
    td1, td2 = side(c, d, a), side(c, d, b)
    if touch:
        return tc1 * tc2 <= 0 and td1 * td2 <= 0
    return tc1 * tc2 < 0 and td1 * td2 < 0

"""primary function"""
def prod(lst:list[int|str], mod = None) -> int|str:
    """product 文字列の場合連結"""
    ans = 1
    if type(lst[0]) in isnum:
        for i in lst:
            ans *= i
            if mod: ans %= mod
        return ans
    else:
        return "".join(lst)

def sigma(first:int, diff:int, term:int) -> int:
    """Sum of the arithmetic progression first, first+diff, ... with term terms."""
    first_plus_last = first * 2 + (term - 1) * diff
    return term * first_plus_last // 2

def xgcd(a:int, b:int) -> tuple[int,int,int]:
    """Extended Euclidean algorithm: return (g, x, y) with a*x + b*y == g."""
    x_prev, x_cur = 1, 0
    y_prev, y_cur = 0, 1
    while b != 0:
        q = a // b
        a, b = b, a % b
        x_prev, x_cur = x_cur, x_prev - q * x_cur
        y_prev, y_cur = y_cur, y_prev - q * y_cur
    return a, x_prev, y_prev

def modinv(a:int, mod = mod) -> int:
    """Modular inverse of a modulo mod, or -1 when it does not exist.

    An inverse exists only when gcd(a, mod) == 1.
    """
    g, x, _ = xgcd(a, mod)
    # reduce with the modulus parameter (previously referenced an undefined name `m`)
    return -1 if g != 1 else x % mod

def nth_root(x:int, n:int, is_x_within_64bit = True) -> int:
    """Return floor(x ** (1/n)) exactly, using an integer binary search.

    x <= 1 or n == 1 returns x unchanged. With is_x_within_64bit the search
    starts from a precomputed exclusive upper bound and n >= 64 short-circuits
    to 1; otherwise the upper bound is x itself.
    """
    # upper_bounds[n]: exclusive upper bound on the n-th root of a 64-bit value
    upper_bounds = [-1, -1, 4294967296, 2642246, 65536, 7132, 1626, 566, 256, 139, 85, 57, 41, 31, 24, 20, 16, 14, 12, 11, 10, 9, 8, 7, 7, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
    if x <= 1 or n == 1:
        return x
    if not is_x_within_64bit:
        hi = x
    elif n >= 64:
        return 1
    else:
        hi = upper_bounds[n]

    lo = 0
    while abs(lo - hi) > 1:
        mid = (lo + hi) // 2
        if mid ** n <= x:
            lo = mid
        else:
            hi = mid
    return lo

def pi_base(p:list) -> Iterator:
    """Enumerate all mixed-radix digit lists whose i-th digit ranges over range(p[i]).

    Lists are produced in lexicographic order with the last digit changing
    fastest (like itertools.product). Each yielded list is a fresh copy, so
    collecting the results (e.g. list(pi_base(p))) is safe.
    An empty p yields a single empty list; any radix <= 0 yields nothing
    (previously an empty p raised IndexError and a zero radix looped forever).
    """
    l = len(p)
    if any(r <= 0 for r in p):
        return
    num = [0]*l
    while True:
        yield num[:]
        # increment from the least-significant (last) digit, propagating carries
        i = l - 1
        while i >= 0:
            num[i] += 1
            if num[i] < p[i]:
                break
            num[i] = 0
            i -= 1
        else:
            # carried out of the most-significant digit: enumeration finished
            return

"""prime"""
def primefact(n:int) -> dict[int,int]:
    """Prime factorisation by trial division, returned as {prime: exponent}."""
    factors = {}
    p = 2
    while p * p <= n:
        exp = 0
        q, r = divmod(n, p)
        while r == 0:
            n, exp = q, exp + 1
            q, r = divmod(n, p)
        if exp:
            factors[p] = exp
        p += 1
    # whatever remains above sqrt of the original n is itself prime
    if n != 1:
        factors[n] = 1
    return factors

def primenumber(lim:int, get = None) -> list[int]:
    """Sieve of Eratosthenes over 0..lim.

    get is None : list of primes <= lim
    get >= 1    : 0/1 flag list, flags[i] == 1 iff i is prime
    get < 1     : prefix counts, counts[i] == number of primes <= i
    """
    size = lim + 1
    flags = [1] * size
    flags[0], flags[1] = 0, 0
    primes = []
    for i in range(2, size):
        if flags[i]:
            primes.append(i)
            # clear every proper multiple of i in one slice assignment
            flags[2 * i::i] = [0] * len(range(2 * i, size, i))

    counts = [0] * size
    for i in range(1, size):
        counts[i] = counts[i - 1] + flags[i]

    if get is None:
        return primes
    return flags if get >= 1 else counts

def divisors(n:int) -> list[int] :
    """Return all positive divisors of n in ascending order.

    n < 1 has no positive divisors and returns [] (previously IndexError).
    """
    if n < 1:
        return []
    small, large = [], []
    i = 1
    while i * i <= n:
        if n % i == 0:
            small.append(i)
            large.append(n // i)
        i += 1
    # a perfect square would otherwise list its square root twice
    if small[-1] == large[-1]:
        large.pop()
    return small + large[::-1]

"""binary number"""
def lenbit(bit):
    """Number of bits needed to represent bit (int.bit_length)."""
    return bit.bit_length()

def popcnt(n:int) -> int:
    """Count set bits with the SWAR technique; valid for values below 2**64.

    int.bit_count() is the built-in alternative.
    """
    steps = (
        (1, 0x5555555555555555),
        (2, 0x3333333333333333),
        (4, 0x0f0f0f0f0f0f0f0f),
        (8, 0x00ff00ff00ff00ff),
        (16, 0x0000ffff0000ffff),
        (32, 0x00000000ffffffff),
    )
    c = n
    for shift, mask in steps:
        # add adjacent bit-fields of width `shift` in parallel
        c = (c & mask) + ((c >> shift) & mask)
    return c

def binchange(n:int,fill0 = None) -> str:
    """Convert an int to a binary string; fill0 (if truthy) is the zero-padded width."""
    if not fill0:
        return format(n, "b")
    return format(n, f"0{fill0}b")

"""list"""
def prefix_op(lst:list, op = lambda x,y:x+y, e = 0) -> list:
    """Prefix fold: res[0] = e and res[i+1] = op(res[i], lst[i]); defaults to prefix sums."""
    res = [e]
    for item in lst:
        res.append(op(res[-1], item))
    return res

def suffix_op(lst:list, op = lambda x,y:x+y, e = 0) -> list:
    """Suffix fold: res[n] = e and res[i] = op(res[i+1], lst[i]); defaults to suffix sums."""
    res = [e]
    for item in reversed(lst):
        res.append(op(res[-1], item))
    res.reverse()
    return res

def acc_sum(lst:list, dim = 2) -> list:
    """Prefix-sum table with a zero border on every axis.

    dim == 2: res[i][j] = sum of lst[y][x] for y < i, x < j; shape (h+1) x (w+1).
    dim == 3: the analogous 3-D table of shape (d1+1) x (d2+1) x (d3+1).
    Row widths are taken from the first row. Any other dim returns None.
    """
    if dim == 2:
        h, w = len(lst), len(lst[0])
        table = [[0] * (w + 1)]
        for i in range(h):
            line = [0]
            for j in range(w):
                line.append(line[j] + lst[i][j])
            table.append(line)
        # accumulate the row prefix sums down each column
        for j in range(1, w + 1):
            for i in range(1, h + 1):
                table[i][j] += table[i - 1][j]
        return table

    if dim == 3:
        d1, d2, d3 = len(lst), len(lst[0]), len(lst[0][0])
        table = [[[0] * (d3 + 1) for _ in range(d2 + 1)]]
        for i in range(d1):
            plane = [[0] * (d3 + 1)]
            for j in range(d2):
                line = [0]
                for k in range(d3):
                    line.append(line[k] + lst[i][j][k])
                plane.append(line)
            table.append(plane)
        # accumulate along the first axis, then along the second
        for j in range(1, d2 + 1):
            for k in range(1, d3 + 1):
                for i in range(1, d1 + 1):
                    table[i][j][k] += table[i - 1][j][k]
        for k in range(1, d3 + 1):
            for i in range(1, d1 + 1):
                for j in range(1, d2 + 1):
                    table[i][j][k] += table[i][j - 1][k]
        return table

    return None

def mex(lst:list) -> int:
    """Smallest non-negative integer that does not occur in lst."""
    seen = set(lst)
    # a set of size s cannot contain all of 0..s, so the answer is at most s
    return next(v for v in range(len(seen) + 1) if v not in seen)

def inversion_cnt(lst:list, flag = None) -> list[int]:
    """Per-element inversion contributions; their sum is the inversion count.

    lst: if flag is None, lst must be a permutation of 1..n (values are shifted
        to 0..n-1). Pass any non-None flag for arbitrary comparable values,
        which are coordinate-compressed first.
    Returns a list (despite the old `-> int` annotation): ans[i] is
    ft.sum(v_i + 1, n) taken before v_i is inserted, i.e. the number of earlier
    elements larger than lst[i] -- assuming FenwickTree.sum(l, r) sums the
    half-open range [l, r) (FenwickTree is defined elsewhere; confirm).
    """
    n = len(lst)
    if flag is not None:
        lst = Compress(lst).comp
    else:
        lst = [x - 1 for x in lst]

    ft = FenwickTree(n)
    ans = [0]*n  # ans[i]: contribution of the i-th element to the inversion count
    for i, v in enumerate(lst):
        ans[i] = ft.sum(v+1, n)
        ft.add(v, 1)
    return ans

def doubling(nex:list, k:int = None ,a:list = None) -> list:
    """Binary lifting over the transition table nex (nex[i] = index after one step).

    k None: return the 60-level table, where table[t][i] is the index reached
        from i after 2**t steps.
    k given: return res with res[i] = start[f^k(i)], where f = nex and start is
        a (copied) or, when a is None, nex itself.
    NOTE(review): with a None the start array is nex, so the result is
    f^(k+1)(i) rather than f^k(i); confirm that is intended.
    """
    size = len(nex)
    depth = 60 if k is None else (k + 1).bit_length()

    table = [[-1] * size for _ in range(depth)]
    table[0] = nex[:]
    for lvl in range(1, depth):
        prev = table[lvl - 1]
        table[lvl] = [prev[prev[i]] for i in range(size)]

    if k is None:
        return table

    cur = nex[:] if a is None else a[:]
    for bit, step in enumerate(table):
        if (k >> bit) & 1:
            cur = [cur[step[i]] for i in range(size)]
    return cur

def swapcnt(a:list, b:list) -> int:
    """Minimum number of swaps that turns a into b.

    a and b are expected to hold the same distinct elements (permutations of
    each other); -1 is returned when their sorted contents differ.
    """
    if sorted(a) != sorted(b):
        return -1

    # t links values along the permutation cycles being traced; every step
    # taken along a chain accounts for one swap
    t = dict()
    cnt = 0
    # iterate pairwise (the original loop used an undefined global `n`)
    for x, y in zip(a, b):
        if x == y:
            continue
        if x in t:
            while x in t:
                x_ = t[x]
                del t[x]
                x = x_
                cnt += 1
                if x == y:
                    break
            else:
                t[y] = x
        else:
            t[y] = x

    return cnt

"""matrix"""
def mul_matrix(A, B, mod = mod):
    """Matrix product A*B with each entry reduced modulo mod after every addition."""
    inner = len(A[0])
    cols = len(B[0])
    out = []
    for row in A:
        line = []
        for c in range(cols):
            acc = 0
            for t in range(inner):
                acc = (acc + row[t] * B[t][c]) % mod
            line.append(acc)
        out.append(line)
    return out

def pow_matrix(mat, exp, mod = mod):
    """mat ** exp (entries modulo mod) by binary exponentiation; exp <= 0 gives the identity."""
    size = len(mat)
    result = [[int(r == c) for c in range(size)] for r in range(size)]
    base = mat
    while exp > 0:
        exp, bit = divmod(exp, 2)
        if bit == 1:
            result = mul_matrix(result, base, mod)
        base = mul_matrix(base, base, mod)
    return result

"""enumerate"""
def fact_enu(lim, mod = mod):
    """Return (fac, divfac) with fac[i] = i! % mod and divfac[i] = (i!)^-1 % mod.

    Both lists cover 0 <= i <= max(lim, 0). mod must make lim! invertible
    (e.g. a prime larger than lim); otherwise pow raises ValueError.
    Uses one modular exponentiation plus a backward pass instead of one per
    element.
    """
    top = max(lim, 0)
    fac = [1] * (top + 1)
    for i in range(1, top + 1):
        fac[i] = fac[i - 1] * i % mod

    divfac = [1] * (top + 1)
    divfac[top] = pow(fac[top], -1, mod)
    # 1/(i-1)! = i * 1/i!
    for i in range(top, 0, -1):
        divfac[i - 1] = divfac[i] * i % mod
    return fac, divfac

class Comb_enu: # precomputed binomial coefficients modulo a prime
    def __init__(self,lim,mod = mod):
        """
        Precompute factorial tables up to lim modulo mod (mod must be prime),
        so comb(a, b) is O(1) for 0 <= b <= a <= lim.
        """
        self.fac = [1,1]    # i! % mod
        self.inv = [1,1]    # modular inverse of i
        self.finv = [1,1]   # (i!)^-1 % mod
        self.mod = mod
        for i in range(2,lim+1):
            self.fac.append(self.fac[i-1]*i%mod)
            # linear recurrence: inv[i] = -(mod // i) * inv[mod % i]  (mod p)
            self.inv.append(-self.inv[mod%i]*(mod//i)%mod)
            self.finv.append(self.finv[i-1]*self.inv[i]%mod)

    def comb(self,a,b):
        """C(a, b) % mod; 0 whenever b < 0, a < 0 or a < b.

        b < 0 previously indexed finv from the end and returned garbage.
        """
        if b < 0 or a < b:
            return 0
        return self.fac[a]*self.finv[b]*self.finv[a-b]%self.mod

"""str"""
def int_0(str,l,r = None, over_ok = False):
    """Return digits [l, r) of the digit string `str` as an int.

    Digits are counted 0-indexed from the right (least significant first);
    r None means "through the most significant digit". An empty range gives 0.
    Returns None when l exceeds the length, or when r exceeds it and over_ok
    is False; with over_ok an oversized r is clamped to the length
    (previously int("") raised ValueError in that case).
    """
    lstr = len(str)
    if l > lstr:
        return None

    hi = lstr - l  # slice end (exclusive), counted from the left
    if r is None:
        lo = 0
    elif r > lstr:
        if not over_ok:
            return None
        lo = 0
    else:
        lo = lstr - r

    digits = str[lo:hi]
    return int(digits) if digits else 0

def lis(l):
    """Return one longest strictly increasing subsequence of l.

    O(n log n) patience-style algorithm with reconstruction. Unlike the
    previous version it needs no "infinity" sentinel (so any comparable
    values work) and returns [] for an empty input instead of raising.
    """
    n = len(l)
    # tails[p]: last value of the best increasing subsequence of length p+1 seen so far
    tails = []
    pos = [0] * n  # pos[i]: slot (length - 1) that l[i] occupied when processed
    for i, x in enumerate(l):
        p = bisect_left(tails, x)
        if p == len(tails):
            tails.append(x)
        else:
            tails[p] = x
        pos[i] = p

    # walk backwards, taking the last element seen for each slot in turn
    target = len(tails) - 1
    ans = [0] * len(tails)
    for i in range(n - 1, -1, -1):
        if target < 0:
            break
        if pos[i] == target:
            ans[target] = l[i]
            target -= 1
    return ans

"""table operation"""
def copy_table(table):
    """Return a new list-of-lists copy of a 2D table; the width comes from table[0]."""
    width = len(table[0])
    return [[row[j] for j in range(width)] for row in table]

def rotate_table(table):
    """Rotate a 2D table 90 degrees counter-clockwise (rows returned as lists)."""
    return [list(col) for col in reversed(list(zip(*table)))]

def transpose_table(l):
    """Swap rows and columns of a 2D table (rows returned as lists)."""
    return list(map(list, zip(*l)))

def bitconvert_table(table, letter1="#", rev=False):
    """Encode each row as an int with a 1-bit wherever the cell equals letter1.

    rev False: the leftmost cell is the most significant bit.
    rev True : the leftmost cell is bit 0.
    The width is taken from table[0].
    """
    width = len(table[0])
    masks = []
    for row in table:
        bits = 0
        for pos in range(width):
            col = pos if rev else width - 1 - pos
            if row[col] == letter1:
                bits |= 1 << pos
        masks.append(bits)
    return masks

"""sort"""
def argment_sort(points):
    """Sort 2D points by polar angle, starting at the positive x-axis.

    Points with y > 0, or y == 0 and x >= 0, come first; the rest follow.
    Each half is ordered with cross_pro as the comparator. Points are
    returned as [x, y] lists.
    """
    upper, lower = [], []
    for x, y in points:
        half = upper if (y > 0 or (y == 0 and x >= 0)) else lower
        half.append([x, y])
    by_angle = cmp_to_key(cross_pro)
    upper.sort(key=by_angle)
    lower.sort(key=by_angle)
    return upper + lower

def quick_sort(lst, comparision, left = 0, right = -1):
    """In-place quicksort of lst[left..right] (Hoare-style partition).

    comparision(p, q) must return True when p has to come before q (p < q).
    right=-1 means the last index. Fixes: the recursive calls now pass the
    comparator (it was dropped, so recursion raised TypeError), and an empty
    list is a no-op instead of ZeroDivisionError.
    """
    if not lst:
        return
    i = left
    if right == -1:
        right %= len(lst)
    j = right
    dpivot = lst[(i+j)//2]

    while True:
        # advance both cursors past elements already on the correct side
        while comparision(lst[i],dpivot):
            i += 1
        while comparision(dpivot,lst[j]):
            j -= 1
        if i >= j:
            break

        lst[i],lst[j] = lst[j],lst[i]
        i += 1
        j -= 1

    if left < i - 1:
        quick_sort(lst, comparision, left, i - 1)
    if right > j + 1:
        quick_sort(lst, comparision, j + 1, right)

def bubble_sort(lst):
    """Sort lst ascending in place and return the number of swaps performed.

    The swap count equals the inversion count of the original list.
    Fixes: `reversed(range(i+1),n)` raised TypeError, the body referenced an
    undefined `a`, and the comparison sorted descending.
    """
    cnt = 0
    n = len(lst)
    for i in range(n):
        # bubble the smallest remaining element down to position i
        for j in reversed(range(i+1, n)):
            if lst[j] < lst[j-1]:
                lst[j],lst[j-1] = lst[j-1],lst[j]
                cnt += 1
    return cnt

def topological_sort(egde, inedge=None):
    """Kahn's algorithm: return the vertices of a DAG in topological order.

    egde: adjacency lists, egde[v] = iterable of successors of v (the
        parameter name is kept as-is for keyword-argument compatibility).
    inedge: optional precomputed in-degree list; it is copied, not consumed.
    Vertices on or downstream of a cycle never reach in-degree 0 and are
    omitted, so len(result) < n signals a cycle.
    """
    edge = egde  # the body previously referred to an undefined `edge`
    n = len(edge)

    if inedge is None:
        inedge = [0]*n
        for v in range(n):
            for adj in edge[v]:
                inedge[adj] += 1
    else:
        inedge = list(inedge)

    order = [v for v in range(n) if inedge[v] == 0]
    que = deque(order)
    while que:
        q = que.popleft()
        for w in edge[q]:
            inedge[w] -= 1
            if inedge[w] == 0:
                que.append(w)
                order.append(w)
    return order

"""graph fanctions"""
def dijkstra(edge, start=0, goal=None):
    """Single-source shortest paths for non-negative weights.

    edge[u] is an iterable of (v, weight) pairs. Returns the whole distance
    list (inf for unreachable vertices), or only the distance to goal when
    goal is given. Time complexity O((V + E) log E).
    """
    dist_to = [inf] * len(edge)
    dist_to[start] = 0
    heap = [(0, start)]

    while heap:
        d, u = heappop(heap)
        if d > dist_to[u]:
            continue  # stale heap entry
        for v, w in edge[u]:
            cand = d + w
            if cand < dist_to[v]:
                dist_to[v] = cand
                heappush(heap, (cand, v))

    return dist_to if goal is None else dist_to[goal]

def warshallfloyd(dis):
    """All-pairs shortest paths (Floyd-Warshall), updating dis in place and returning it.

    dis[i][j] is the direct edge weight (a large value for "no edge"); the
    diagonal is reset to 0 first.
    """
    size = len(dis)
    for v in range(size):
        dis[v][v] = 0

    for mid in range(size):
        via_row = dis[mid]
        for src in range(size):
            row = dis[src]
            for dst in range(size):
                cand = row[mid] + via_row[dst]
                if cand < row[dst]:
                    row[dst] = cand
    return dis

def bellmanford(edge, start=0, goal=None):
    """Single-source shortest paths allowing negative weights (Bellman-Ford).

    edge[u] is an iterable of (v, cost) pairs. Runs 2*n relaxation rounds;
    any relaxation still possible in round n-1 or later means v is affected by
    a negative cycle reachable from start, and v is marked -inf (so the
    marking propagates to everything downstream of the cycle).

    Returns:
        goal given: None if goal is affected by a negative cycle (None rather
            than a number so it cannot be confused with a distance of 0),
            otherwise dis[goal] (inf when unreachable).
        goal None: the whole distance list.
    Fixes: the marking branch used an undefined `i`, unreachable vertices
    (dis == inf) could relax neighbours through negative edges, goal=None
    crashed, and the unreachable path-restoration code referenced undefined names.
    """
    n = len(edge)
    dis = [inf]*n
    dis[start] = 0

    for t in range(2*n):
        for u in range(n):
            if dis[u] == inf:
                continue  # unreachable vertices must not relax anything
            for v, cost in edge[u]:
                cand = dis[u] + cost
                if dis[v] > cand:
                    if t >= n-1:
                        if v == goal:
                            return None
                        dis[v] = -inf
                    else:
                        dis[v] = cand

    if goal is None:
        return dis
    return dis[goal]

#ループ検出書くの嫌いなので用意しましょう
def loop(g):
    """Placeholder for cycle detection on graph g; not implemented, always returns None."""
    return None

"""data stucture"""
# Sorted containers backed by buckets of sorted lists
# https://github.com/tatyam-prime/SortedSet?tab=readme-ov-file
class SortedSet(Generic[T]):
    """Sorted set of unique elements (tatyam-prime/SortedSet).

    Elements live in ``self.a``: a list of buckets, each a sorted list, whose
    concatenation is sorted and duplicate-free. ``self.size`` is the element
    count. An empty set has no buckets at all (``self.a == []``).
    """
    # initial bucket count is ceil(sqrt(n / BUCKET_RATIO))
    BUCKET_RATIO = 16
    # a bucket is split in half once its length exceeds (bucket count) * SPLIT_RATIO
    SPLIT_RATIO = 24
    
    def __init__(self, a: Iterable[T] = []) -> None:
        "Make a new SortedSet from iterable. / O(N) if sorted and unique / O(N log N)"
        # the default list is only read (copied by list()), never mutated
        a = list(a)
        n = len(a)
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        if any(a[i] >= a[i + 1] for i in range(n - 1)):
            # drop duplicates; a is sorted, so equal values are adjacent
            a, b = [], a
            for x in b:
                if not a or a[-1] != x:
                    a.append(x)
        n = self.size = len(a)
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        # split into num_bucket nearly equal contiguous slices (none when n == 0)
        self.a = [a[n * i // num_bucket : n * (i + 1) // num_bucket] for i in range(num_bucket)]

    def __iter__(self) -> Iterator[T]:
        # ascending order
        for i in self.a:
            for j in i: yield j

    def __reversed__(self) -> Iterator[T]:
        # descending order
        for i in reversed(self.a):
            for j in reversed(i): yield j
    
    def __eq__(self, other) -> bool:
        # element-wise comparison against any iterable
        return list(self) == list(other)
    
    def __len__(self) -> int:
        return self.size
    
    def __repr__(self) -> str:
        # shows the bucket structure
        return "SortedSet" + str(self.a)
    
    def __str__(self) -> str:
        # e.g. {1, 2, 3}
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x: T) -> Tuple[List[T], int, int]:
        "return the bucket, index of the bucket and position in which x should be. self must not be empty."
        # first bucket whose maximum is >= x; if x exceeds everything the loop
        # finishes without break and the last bucket is used
        for i, a in enumerate(self.a):
            if x <= a[-1]: break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x: T) -> bool:
        if self.size == 0: return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def add(self, x: T) -> bool:
        "Add an element and return True if added. / O(√N)"
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return True
        a, b, i = self._position(x)
        if i != len(a) and a[i] == x: return False
        a.insert(i, x)
        self.size += 1
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            # bucket grew too large: replace it by its two halves
            mid = len(a) >> 1
            self.a[b:b+1] = [a[:mid], a[mid:]]
        return True
    
    def _pop(self, a: List[T], b: int, i: int) -> T:
        # remove a[i] from bucket a (= self.a[b]); empty buckets are deleted
        ans = a.pop(i)
        self.size -= 1
        if not a: del self.a[b]
        return ans

    def discard(self, x: T) -> bool:
        "Remove an element and return True if removed. / O(√N)"
        if self.size == 0: return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x: return False
        self._pop(a, b, i)
        return True
    
    def lt(self, x: T) -> Optional[T]:
        "Find the largest element < x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]

    def le(self, x: T) -> Optional[T]:
        "Find the largest element <= x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]

    def gt(self, x: T) -> Optional[T]:
        "Find the smallest element > x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]

    def ge(self, x: T) -> Optional[T]:
        "Find the smallest element >= x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
    
    def __getitem__(self, i: int) -> T:
        "Return the i-th element."
        # negative i counts from the end; out of range raises IndexError
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0: return a[i]
        else:
            for a in self.a:
                if i < len(a): return a[i]
                i -= len(a)
        raise IndexError
    
    def pop(self, i: int = -1) -> T:
        "Pop and return the i-th element."
        if i < 0:
            # b counts buckets from the end, so ~b is that bucket's index from the end
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                if i >= 0: return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a): return self._pop(a, b, i)
                i -= len(a)
        raise IndexError
    
    def index(self, x: T) -> int:
        "Count the number of elements < x."
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        "Count the number of elements <= x."
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans

class SortedList(Generic[T]):
    """Sorted multiset allowing duplicates (tatyam-prime's SortedMultiset).

    Elements live in ``self.a``: a list of buckets, each a sorted list, whose
    concatenation is sorted. ``self.size`` is the element count. An empty
    list has no buckets at all (``self.a == []``).
    """
    # initial bucket count is ceil(sqrt(n / BUCKET_RATIO))
    BUCKET_RATIO = 16
    # a bucket is split in half once its length exceeds (bucket count) * SPLIT_RATIO
    SPLIT_RATIO = 24
    
    def __init__(self, a: Iterable[T] = []) -> None:
        "Make a new SortedList from iterable. / O(N) if sorted / O(N log N)"
        # the default list is only read (copied by list()), never mutated
        a = list(a)
        n = self.size = len(a)
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        # split into num_bucket nearly equal contiguous slices (none when n == 0)
        self.a = [a[n * i // num_bucket : n * (i + 1) // num_bucket] for i in range(num_bucket)]

    def __iter__(self) -> Iterator[T]:
        # ascending order
        for i in self.a:
            for j in i: yield j

    def __reversed__(self) -> Iterator[T]:
        # descending order
        for i in reversed(self.a):
            for j in reversed(i): yield j
    
    def __eq__(self, other) -> bool:
        # element-wise comparison against any iterable
        return list(self) == list(other)
    
    def __len__(self) -> int:
        return self.size
    
    def __repr__(self) -> str:
        # NOTE: labelled "SortedMultiset" (the upstream class name), not "SortedList"
        return "SortedMultiset" + str(self.a)
    
    def __str__(self) -> str:
        # e.g. {1, 1, 2}
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x: T) -> Tuple[List[T], int, int]:
        "return the bucket, index of the bucket and position in which x should be. self must not be empty."
        # first bucket whose maximum is >= x; if x exceeds everything the loop
        # finishes without break and the last bucket is used
        for i, a in enumerate(self.a):
            if x <= a[-1]: break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x: T) -> bool:
        if self.size == 0: return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def count(self, x: T) -> int:
        "Count the number of x."
        return self.index_right(x) - self.index(x)

    def add(self, x: T) -> None:
        "Add an element. / O(√N)"
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return
        a, b, i = self._position(x)
        a.insert(i, x)
        self.size += 1
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            # bucket grew too large: replace it by its two halves
            mid = len(a) >> 1
            self.a[b:b+1] = [a[:mid], a[mid:]]
    
    def _pop(self, a: List[T], b: int, i: int) -> T:
        # remove a[i] from bucket a (= self.a[b]); empty buckets are deleted
        ans = a.pop(i)
        self.size -= 1
        if not a: del self.a[b]
        return ans

    def discard(self, x: T) -> bool:
        "Remove an element and return True if removed. / O(√N)"
        # removes a single occurrence of x
        if self.size == 0: return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x: return False
        self._pop(a, b, i)
        return True

    def lt(self, x: T) -> Optional[T]:
        "Find the largest element < x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]

    def le(self, x: T) -> Optional[T]:
        "Find the largest element <= x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]

    def gt(self, x: T) -> Optional[T]:
        "Find the smallest element > x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]

    def ge(self, x: T) -> Optional[T]:
        "Find the smallest element >= x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
    
    def __getitem__(self, i: int) -> T:
        "Return the i-th element."
        # negative i counts from the end; out of range raises IndexError
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0: return a[i]
        else:
            for a in self.a:
                if i < len(a): return a[i]
                i -= len(a)
        raise IndexError
    
    def pop(self, i: int = -1) -> T:
        "Pop and return the i-th element."
        if i < 0:
            # b counts buckets from the end, so ~b is that bucket's index from the end
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                if i >= 0: return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a): return self._pop(a, b, i)
                i -= len(a)
        raise IndexError

    def index(self, x: T) -> int:
        "Count the number of elements < x."
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        "Count the number of elements <= x."
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans

class Deque: # deque that also supports O(1) indexed access away from the ends
    """Ring-buffer deque: O(1) append/pop at both ends and O(1) indexing.

    buf has N slots of which at most N - 1 are used, so head == tail means
    empty. When full, the buffer grows to 2N - 1 slots. Popped slots are
    cleared so removed objects are not kept alive by the buffer.
    """
    def __init__(self, src_arr=(), max_size=300000):
        # src_arr is only copied; an immutable default avoids a shared mutable list
        self.N = max(max_size, len(src_arr)) + 1
        self.buf = list(src_arr) + [None] * (self.N - len(src_arr))
        self.head = 0
        self.tail = len(src_arr)
    def __index(self, i):
        # logical index (negative counts from the end) -> physical slot
        l = len(self)
        if not -l <= i < l: raise IndexError('index out of range: ' + str(i))
        if i < 0:
            i += l
        return (self.head + i) % self.N
    def __extend(self):
        # insert N - 1 free slots right after tail (just before head when wrapped)
        ex = self.N - 1
        self.buf[self.tail+1 : self.tail+1] = [None] * ex
        self.N = len(self.buf)
        if self.head > 0:
            self.head += ex
    def is_full(self):
        return len(self) >= self.N - 1
    def is_empty(self):
        return len(self) == 0
    def append(self, x):
        if self.is_full(): self.__extend()
        self.buf[self.tail] = x
        self.tail += 1
        self.tail %= self.N
    def appendleft(self, x):
        if self.is_full(): self.__extend()
        self.buf[(self.head - 1) % self.N] = x
        self.head -= 1
        self.head %= self.N
    def pop(self):
        if self.is_empty(): raise IndexError('pop() when buffer is empty')
        self.tail = (self.tail - 1) % self.N
        ret = self.buf[self.tail]
        self.buf[self.tail] = None  # release the reference
        return ret
    def popleft(self):
        if self.is_empty(): raise IndexError('popleft() when buffer is empty')
        ret = self.buf[self.head]
        self.buf[self.head] = None  # release the reference
        self.head += 1
        self.head %= self.N
        return ret
    def __len__(self):
        return (self.tail - self.head) % self.N
    def __iter__(self):
        # front-to-back iteration without per-item bounds checks
        for i in range(len(self)):
            yield self.buf[(self.head + i) % self.N]
    def __getitem__(self, key):
        return self.buf[self.__index(key)]
    def __setitem__(self, key, value):
        self.buf[self.__index(key)] = value
    def __str__(self):
        return 'Deque({0})'.format(str(list(self)))

class WeightedUnionFind: # union-find that tracks potential differences
    """Union-find where every element carries a potential relative to its root.

    get_weight(x) is x's potential relative to its root, and for x, y in the
    same set diff(x, y) reports get_weight(y) - get_weight(x).
    """
    def __init__(self, N):
        self.N = N
        self.parents = [-1] * N  # -1 marks a root
        self.rank = [0] * N      # union-by-rank heuristic
        self.weight = [0] * N    # potential relative to the parent (the root after root())

    def root(self, x):
        """Return the root of x, compressing the path and rebasing weights on the root."""
        if self.parents[x] == -1:
            return x
        rx = self.root(self.parents[x])
        # after the recursive call the parent's weight is relative to rx
        self.weight[x] += self.weight[self.parents[x]]
        self.parents[x] = rx
        return rx
    
    def get_weight(self, x):
        """Potential of x relative to its root."""
        self.root(x)
        return self.weight[x]

    def unite(self, x, y, d):
        '''
        Merge the sets of x and y under the constraint diff(x, y) == d,
        i.e. get_weight(y) - get_weight(x) == d. (The previous docstring
        "A[x] - A[y] = d" only matches if A is the negated weight.)
        Returns False if x and y are already connected with a different
        difference (inconsistent constraint), True otherwise.
        '''
        w = d + self.get_weight(x) - self.get_weight(y)
        rx = self.root(x)
        ry = self.root(y)
        if rx == ry:
            _, d_xy = self.diff(x, y)
            return d_xy == d
        if self.rank[rx] < self.rank[ry]:
            # attach the shallower tree below the deeper one; flip the offset
            rx, ry = ry, rx
            w = -w
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        
        self.parents[ry] = rx
        self.weight[ry] = w
        return True

    def is_same(self, x, y):
        return self.root(x) == self.root(y)
    
    def diff(self, x, y):
        """Return (True, get_weight(y) - get_weight(x)) if connected, else (False, 0)."""
        if self.is_same(x, y):
            return True, self.get_weight(y) - self.get_weight(x)
        else:
            return False, 0
            return False, 0

"""binary search"""
def bi_int(pred, ok = 0, ng = inf):
    """
    Integer binary search on a monotone predicate.

    Picture pred as True on [lowlim, ans) and False on [ans, uplim); `ok` starts
    on the True side and `ng` on the False side (either may be the larger one).
    Returns the True-side value adjacent to the boundary. If pred(ok) is already
    False (the condition is never satisfied), `ok` is returned unchanged.
    """
    if pred(ok):
        while abs(ng - ok) > 1:
            mid = ok + (ng - ok) // 2
            if pred(mid):
                ok = mid
            else:
                ng = mid
    return ok

def bi_float(pred, ok = 0, ng = inf, error = 10**(-9)):
    """
    Real-valued binary search on a monotone predicate.

    Picture pred as True on [lowlim, ans) and False on [ans, uplim); `ok` starts
    on the True side and `ng` on the False side. If pred(ok) is False (the
    condition is never satisfied), `ok` is returned unchanged.

    The search stops when either the relative gap |ng - ok| / |ng| or the
    absolute gap |ng - ok| is at most `error`. It also stops when ok and ng are
    adjacent floats, so no further halving is possible (e.g. error == 0).

    NOTE(review): `ng` should be finite. With an infinite `ng` the gap test is
    satisfied at once and `ok` is returned as is.
    """
    if not pred(ok):
        return ok

    while True:
        gap = abs(ng - ok)
        # Compare gap <= error*|ng| instead of dividing by |ng|, so ng == 0 is safe.
        if gap <= error * abs(ng) or gap <= error:
            break
        mid = ok + (ng - ok) / 2
        if mid == ok or mid == ng:
            # Float precision exhausted: the interval cannot shrink any more.
            break
        if pred(mid):
            ok = mid
        else:
            ng = mid

    return ok

"""compress"""
class Compress: # coordinate compression (one-dimensional)
    """
    Map the distinct values of `arr` onto 0..k-1, preserving their order.

    translator     : dict, original value -> compressed index
    inv_translator : sorted list of distinct values, compressed index -> value
    comp           : `arr` with every element replaced by its compressed index
    """
    def __init__(self, arr):
        values = sorted(set(arr))
        self.translator = {value: idx for idx, value in enumerate(values)}
        self.inv_translator = values
        self.comp = [self.translator[x] for x in arr]

    # original -> compressed
    def to_comp(self, x):
        return self.translator[x]

    # compressed -> original
    def from_comp(self, v):
        return self.inv_translator[v]

    # compress every element of lst
    def lst_comp(self, lst):
        return [self.to_comp(i) for i in lst]

class Compress2D: # coordinate compression of [x, y] pairs, each axis independently
    """
    Compress the x coordinates and the y coordinates of `arr` separately.

    x, y : Compress objects for the first / second coordinates
    comp : `arr` with every pair replaced by [compressed x, compressed y]
    """
    def __init__(self, arr):
        self.x = Compress([x for x, y in arr])
        self.y = Compress([y for x, y in arr])
        self.comp = [[self.x.translator[x], self.y.translator[y]] for x, y in arr]

    # original (x, y) -> compressed (cx, cy)
    def to_comp(self, x):
        return (self.x.to_comp(x[0]), self.y.to_comp(x[1]))

    # compressed (cx, cy) -> original (x, y)
    def from_comp(self, v):
        # Must use the inverse tables: `translator` maps original -> compressed.
        return (self.x.from_comp(v[0]), self.y.from_comp(v[1]))

class RollingHash: # polynomial rolling hash
    """
    Prefix polynomial hashes of `string` for O(1) substring hashing.

    hash[i] is the hash of string[:i]; pw[i] is base**i mod `mod`.
    """
    def __init__(self, string, base = 37, mod = 10**9 + 9):
        self.mod = mod
        prefix = [0]
        power = [1]
        for ch in string:
            prefix.append((prefix[-1] * base + ord(ch)) % mod)
            power.append(power[-1] * base % mod)
        self.hash = prefix
        self.pw = power

    def get(self, l, r):
        """Hash of s[l:r]."""
        span = r - l
        return (self.hash[r] - self.hash[l] * self.pw[span]) % self.mod

class ZobristHash: # equality testing of multisets / sets by random hashing
    """
    Random-key hashing for comparing collections of integers in [0, n].

    as_list=True  : a hash is the sum (mod Mod) of per-value random keys, so two
                    ranges with equal hashes are (w.h.p.) equal as multisets,
                    i.e. one is a rearrangement of the other.
    as_list=False : prefix hash i is the XOR of the keys of the *distinct* values
                    in a[:i], so prefix hashes compare prefixes as sets.
    """
    def __init__(self, n, as_list:bool = False, mod = (1<<61)-1):
        self.N = n
        self.conversion = [random.randint(1, mod - 1) for i in range(n+1)]
        self.as_list = as_list # True: multiset (rearranged list) mode, False: set mode
        self.Mod = mod

    def makehash(self, a:list):
        """Return the prefix hashes h (len(a)+1 entries); h[i] describes a[:i]."""
        la = len(a)
        hashlst = [0]*(la+1)
        if self.as_list:
            # Multiset mode: order-independent sum of keys.
            for i in range(la):
                hashlst[i+1] = (hashlst[i]+self.conversion[a[i]])%self.Mod
            return hashlst

        # Set mode: only the first occurrence of each value contributes.
        seen = set()
        for i in range(la):
            if a[i] in seen:
                # A repeated value leaves the set unchanged: carry the prefix forward.
                hashlst[i+1] = hashlst[i]
                continue
            seen.add(a[i])
            hashlst[i+1] = hashlst[i]^self.conversion[a[i]]
        return hashlst

    def get(self, hashedlst:list, l:int, r:int):
        """
        Combine prefix hashes for the range a[l:r].

        In multiset mode this is the multiset hash of a[l:r]. In set mode it is
        the XOR of the keys of values whose first occurrence in `a` lies in
        [l, r); that equals the set hash of a[l:r] only when l == 0.
        """
        if self.as_list:
            return (hashedlst[r]-hashedlst[l])%self.Mod
        else:
            return hashedlst[r]^hashedlst[l]

"""畳み込み??"""

"""graph"""
class GridSearch:
    """
    Search helpers for a 2-D character grid indexed as table[h][w].

    A cell is a wall when its character is contained in `self.wall`
    (default "#"). `self.dist` is shared by every search on this object and
    is never reset, so successive searches accumulate into the same table.
    `inf` and `DIR_4` are module-level names defined elsewhere in the file
    (DIR_4 presumably the four orthogonal offsets -- confirm).
    """

    def __init__(self, table):
        """Store the grid and allocate the distance table (all inf)."""
        self.table = table
        self.H = len(table)
        self.W = len(table[0])
        self.wall = "#"
        self.dist = [[inf]*self.W for _ in range(self.H)]

    def find(self, c):
        """Return the first (h, w) in row-major order whose cell equals c, else None.

        Useful for locating start / goal markers.
        """
        for h in range(self.H):
            for w in range(self.W):
                if self.table[h][w] == c:
                    return (h,w)
        return None

    def set_wall(self, string):
        """Set the wall characters: any cell whose character is in `string` is a wall."""
        self.wall = string

    def can_start(self, *start):
        """True if the cell is unvisited (dist is inf) and not a wall.

        Accepts either can_start((h, w)) or can_start(h, w).
        """
        if len(start) == 1:
            i,j = start[0][0],start[0][1]
        else:
            i,j = start[0],start[1]

        return self.dist[i][j] == inf and self.table[i][j] not in self.wall

    def island(self, transition = DIR_4):
        """
        Label the connected components of non-wall cells.

        Returns (island_id, island_size): island_id[h][w] is the component
        number (in row-major discovery order) and island_size[h][w] is that
        component's cell count. Both are -1 on wall cells. The results are
        also stored on the object.
        """
        H, W = self.H, self.W
        self.island_id = [[-1]*W for _ in range(H)]
        self.island_size = [[-1]*W for _ in range(H)]

        crr_id = 0
        id2size = dict()
        for sh in range(H):
            for sw in range(W):
                if self.table[sh][sw] in self.wall:
                    continue
                if self.island_id[sh][sw] != -1:
                    continue
                # BFS flood fill of a new component starting at (sh, sw).
                deq = deque()
                deq.append((sh,sw))
                crr_size = 1
                self.island_id[sh][sw] = crr_id
                while deq:
                    h,w = deq.popleft()
                    for dh, dw in transition:
                        nh, nw = h+dh, w+dw
                        if (not 0<= nh < H) or (not 0 <= nw < W):
                            continue
                        if self.table[nh][nw] in self.wall:
                            continue
                        if self.island_id[nh][nw] == -1:
                            self.island_id[nh][nw] = crr_id
                            deq.append((nh, nw))
                            crr_size += 1

                id2size[crr_id] = crr_size
                crr_id += 1

        for h in range(H):
            for w in range(W):
                if self.table[h][w] in self.wall:
                    continue
                self.island_size[h][w] = id2size[self.island_id[h][w]]

        return self.island_id, self.island_size


    def DFS(self, start, goal=None, transition = DIR_4):
        """
        Unit-cost shortest paths from `start` through non-wall cells.

        NOTE: despite the name this is a breadth-first search (FIFO deque).
        input : (start, (goal), (transition)) with start/goal as (h, w)
        output: the distance table, or -- if `goal` is given -- the distance
                to goal (int), or -1 if it is unreachable.
        """
        H, W = self.H, self.W
        # Normalise so that list arguments compare equal to (h, w) tuples.
        start = tuple(start)
        if goal is not None:
            goal = tuple(goal)

        deq = deque()
        deq.append(start)
        self.dist[start[0]][start[1]] = 0

        if start == goal:
            return 0

        while deq:
            h,w = deq.popleft()
            for dh, dw in transition:
                nh = h+dh
                nw = w+dw
                # Outside the grid.
                if (not 0 <= nh < H) or (not 0 <= nw < W):
                    continue

                # Cell character is one of the wall characters.
                if self.table[nh][nw] in self.wall:
                    continue

                new_dist = self.dist[h][w] + 1

                # In BFS the first time the goal is discovered is the shortest.
                if goal is not None and (nh,nw) == goal:
                    return new_dist

                if self.dist[nh][nw] > new_dist:
                    self.dist[nh][nw] = new_dist
                    deq.append((nh,nw))

        if goal is not None:
            return -1

        return self.dist

    def DFS_break(self, start, goal=None, transition = DIR_4):
        """
        0-1 BFS: entering a wall cell (breaking it) costs 1, any other move costs 0.

        input : (start, (goal), (transition)) with start/goal as (h, w)
        output: the distance table, or -- if `goal` is given -- the minimum
                number of walls broken to reach goal, or -1 if unreachable.
        """
        H, W = self.H, self.W
        start = tuple(start)
        if goal is not None:
            goal = tuple(goal)

        deq = deque()
        deq.append(start)
        self.dist[start[0]][start[1]] = 0

        if start == goal:
            return 0

        while deq:
            h,w = deq.popleft()
            now_dist = self.dist[h][w]

            # In 0-1 BFS a cell's distance is final when it is first popped,
            # so the goal test is done here rather than on discovery.
            if goal is not None and (h,w) == goal:
                return now_dist

            for dh, dw in transition:
                nh = h+dh
                nw = w+dw
                # Outside the grid.
                if (not 0 <= nh < H) or (not 0 <= nw < W):
                    continue

                if self.table[nh][nw] in self.wall:
                    # Breaking a wall costs 1: enqueue at the back.
                    if self.dist[nh][nw] > now_dist+1:
                        self.dist[nh][nw] = now_dist+1
                        deq.append((nh,nw))
                elif self.dist[nh][nw] > now_dist:
                    # Ordinary move costs 0: enqueue at the front.
                    self.dist[nh][nw] = now_dist
                    deq.appendleft((nh,nw))

        if goal is not None:
            return -1

        return self.dist

    # Possible variations:

    # distance added when changing direction
    # distance added when breaking a wall

    # other kinds of walls
    # line of sight acting as a wall
    # magnets

    # restricted movement (energy)

class RootedTree:
    """
    Rooted-tree toolkit over an undirected adjacency list G (vertices 0..N-1).

    __allmethod__

    autobuild -> obj : build from edge lines read with input()
    set_root -> None
    is_root, is_leaf -> bool
    yield_edges -> Iterator
    ### set_weight -> None : build the edge -> weight dict
    get_weight -> int : weight of an edge, looked up in that dict
    get_depth -> int : depth from the root
    ### build_depth -> None : compute all depths
    build_des_size -> None : subtree sizes
    centroid_decomposition :
    build_centroid_dist
    is_member_of_centroid_tree
    is_id_larger
    get_higher_centroids_with_self
    yield_centroid_children
    find_lowest_common_centroid
    build_the_centroid / get_the_centroid : a centroid of the whole tree
    dp_from_leaf / rerooting_dp : tree DP
    build_dist_from_root / calc_dist_from_a_node : distances
    build_diameter / get_diameter : diameter
    """
    @classmethod
    def autobuild(cls, N, root = 0, input_index = 1):
        """
        Read the N-1 edges from stdin and build the tree.

        Handles both (u,v) and (u,v,c) lines; the first line decides the format
        for all lines. Labels are shifted down by `input_index`. The tree is
        rooted at `root`; pass None to leave it unrooted.
        """
        G = [[] for _ in range(N)]
        if N == 1:
            obj = cls(G)
            if root is not None:
                obj.set_root(root)
            return obj

        line1 = list(map(int, input().split()))
        assert 2 <= len(line1) <= 3
        weighted = len(line1) == 3

        rows = [line1] + [list(map(int, input().split())) for _ in range(N-2)]
        edge = []
        for row in rows:
            if weighted:
                u, v, c = row
            else:
                u, v = row
            u, v = u-input_index, v-input_index
            G[u].append(v)
            G[v].append(u)
            if weighted:
                edge.append((u, v, c))

        obj = cls(G)
        if weighted:
            obj.set_weight(edge)
        if root is not None:
            # Root at the vertex the caller asked for.
            obj.set_root(root)
        return obj

    def __init__(self, G):
        self.N = len(G)
        self.G = G
        self._rooted = False
        self._has_weight = False
        # Edge (u, v) is encoded as u*_key + v, so vertex ids must be < _key.
        self._key = 10**7

    def set_root(self, root):
        """Root the tree with a BFS from `root`.

        Records parent, children, and ts_order (the BFS visiting order, a
        topological order in which every parent precedes its children).
        """
        assert self._rooted == False
        self.root = root
        n, G = self.N, self.G
        par, ch, ts = [-1]*n, [[] for _ in range(n)], []
        deq = deque([root])
        while deq:
            v = deq.popleft()
            ts.append(v)
            for adj in G[v]:
                if adj == par[v]: continue
                par[adj] = v
                ch[v].append(adj)
                deq.append(adj)
        self.parent, self.children, self.ts_order = par, ch, ts
        self._rooted = True

    def encode(self, u, v): # edge -> int
        return u*self._key + v

    def decode(self, uv): # int -> edge
        return divmod(uv, self._key)

    def is_root(self, v) -> bool:
        return v == self.root

    def is_leaf(self, v) -> bool:
        return len(self.children[v]) == 0

    def yield_edges(self) -> Iterator[tuple]:
        """Iterate parent->child edges, closest to the root first.

        Yields (v, c, weight) when weights are set, otherwise (v, c).
        """
        N, ts, ch = self.N, self.ts_order, self.children
        if self._has_weight:
            wei, en = self.weight, self.encode
            for v in ts:
                for c in ch[v]:
                    yield (v,c,wei[en(v,c)])
        else:
            for v in ts:
                for c in ch[v]:
                    yield (v,c)

    """ weight """
    # Keep weights in a dict so edge -> weight lookups are O(1) (both directions).
    def set_weight(self, edge):
        assert self._has_weight == False
        d = {}
        for u,v,c in edge:
            d[self.encode(u,v)] = d[self.encode(v,u)] = c
        self.weight = d
        self._has_weight = True

    def get_weight(self, u, v) -> int:
        return self.weight[self.encode(u, v)]

    """depth : depth from the root"""
    def get_depth(self, v) -> int:
        # Same as obj.depth[v]; builds the table on first use.
        if not hasattr(self, "depth"):
            self.build_depth()
        return self.depth[v]

    def build_depth(self):
        assert self._rooted
        N, ch, ts = self.N, self.children, self.ts_order
        depth = [0]*N
        for v in ts:
            for c in ch[v]:
                depth[c] = depth[v] + 1
        self.depth = depth


    """subtree_size : subtree sizes"""
    def build_des_size(self):
        """des_size[v] = number of vertices in the subtree rooted at v."""
        assert self._rooted
        if hasattr(self, "des_size"):
            return
        N, ts, par = self.N, self.ts_order, self.parent
        des = [1]*N
        # Reverse BFS order: every child is finished before its parent.
        for i in range(N-1,0,-1):
            v = ts[i]
            p = par[v]
            des[p] += des[v]
        self.des_size = des


    """centroid : centroid decomposition"""
    def centroid_decomposition(self, build_dist=True):
        """
        centroid_id[i] : when the centroid decomposition is performed in DFS
        order, the position at which the centroid subtree whose centroid is
        vertex i appears.

        To explore the vertices of the centroid subtree with centroid cen, start
        at cen and move only through vertices v with
        T.is_id_larger(v, cen) == True.

        centroid_dfs_order : inverse permutation of centroid_id.
        Visiting centroid subtrees in reversed(centroid_dfs_order) order enables
        DP that reuses results for smaller centroid subtrees.
        """
        if hasattr(self, "centroid_id"):
            return

        # The algorithm does not depend on the root, so rooting at 0 is fine.
        if not self._rooted:
            self.set_root(0)

        if not hasattr(self, "des_size"):
            self.build_des_size()

        # size is rewritten during the walk, so work on a copy.
        N, G, size = self.N, self.G, self.des_size[:]
        c_id, c_depth, c_par, c_dfs_order = [-1]*N, [-1]*N, [-1]*N, []

        stack = [(self.root, -1, 0)]
        # Each time a centroid is found, record for "the subtree it is the
        # centroid of": its DFS position, its depth, and the centroid of the
        # parent subtree in the centroid tree.
        for order in range(N):
            v, prev, d = stack.pop()
            while True:
                for adj in G[v]:
                    if c_id[adj] == -1 and size[adj]*2 > size[v]:
                        # Make adj the root of the current subtree and fix up sizes.
                        size[v], size[adj], v = size[v]-size[adj], size[v], adj
                        break
                else:
                    break

            c_id[v], c_depth[v], c_par[v] = order, d, prev
            c_dfs_order.append(v)

            if size[v] > 1:
                for adj in G[v]:
                    if c_id[adj] == -1:
                        stack.append((adj, v, d+1))

        self.centroid_id, self.centroid_depth, self.centroid_parent, self.centroid_dfs_order = c_id, c_depth, c_par, c_dfs_order

        if build_dist == True:
            self.build_centroid_dist()

    def build_centroid_dist(self):
        """
        Call the tree that links centroids the centroid decomposition tree.
        Many problems can be solved on that tree alone, but computing "the
        distance from each centroid subtree's centroid (root)" needs the
        original edges. Once those distances are known, reasoning about the
        centroid decomposition tree is often enough.
        """
        if hasattr(self, "centroid_dist"):
            return False
        if not hasattr(self, "centroid_id"):
            self.centroid_decomposition()

        N, G, c_depth = self.N, self.G ,self.centroid_depth
        is_id_larger = self.is_id_larger

        log = max(c_depth) + 1
        # dist[d][v] : if v belongs to the centroid subtree at depth d, its
        # distance from that subtree's centroid (otherwise -1).

        dist = [[-1]*N for _ in range(log)]
        for cen in range(N):
            d = c_depth[cen]
            stack = [cen]
            dist[d][cen] = 0
            while stack:
                v = stack.pop()
                for adj in G[v]:
                    if dist[d][adj] == -1 and is_id_larger(adj, cen):
                        if self._has_weight:
                            dist[d][adj] = dist[d][v] + self.weight[self.encode(v, adj)]
                        else:
                            dist[d][adj] = dist[d][v] + 1
                        stack.append(adj)

        self.centroid_log, self.centroid_dist = log, dist


    def is_member_of_centroid_tree(self, v, c):
        # Whether vertex v belongs to the centroid subtree with centroid c. O(logN)
        vs = self.get_higher_centroids_with_self(v)
        return c in vs

    def is_id_larger(self, u, v):
        # When searching from centroid c, use is_id_larger(adj, c) to stay inside c's centroid subtree.
        return self.centroid_id[u] > self.centroid_id[v]

    def get_higher_centroids_with_self(self, c):
        # Centroids of every centroid subtree containing c, in ascending order of size. O(logN)
        vs = []
        for d in range(self.centroid_depth[c], -1, -1):
            vs.append(c)
            c = self.centroid_parent[c]
        return vs

    def yield_centroid_children(self, v):
        # In the centroid subtree with centroid v, for each neighbour ch of v
        # inside it, yield (ch, centroid of the child centroid subtree that
        # contains ch) -- i.e. the child centroid subtrees of v's subtree in
        # the centroid decomposition tree.
        G, is_id_larger, c_par = self.G, self.is_id_larger, self.centroid_parent
        for ch in G[v]:
            if is_id_larger(ch, v):
                ch_cen = ch
                while c_par[ch_cen] != v:
                    ch_cen = c_par[ch_cen]
                yield (ch, ch_cen)

    def find_lowest_common_centroid(self, u, v):
        # Centroid of the smallest centroid subtree containing both u and v. O(logN)
        c_depth, c_par = self.centroid_depth, self.centroid_parent
        du, dv = c_depth[u], c_depth[v]
        if du > dv:
            u,v = v,u
            du,dv = dv,du
        for _ in range(dv - du):
            v = c_par[v]
        while u != v:
            u,v = c_par[u],c_par[v]
        return u


    def build_the_centroid(self):
        """ For when a single centroid of the whole tree is enough. O(N)

        Returns False if already built, True otherwise.
        """
        if not self._rooted:
            self.set_root(0)
        if hasattr(self, "the_centroid"):
            return False
        if hasattr(self, "centroid_id"):
            # The first centroid found by the decomposition is a centroid of the
            # whole tree. centroid_id[0] would be vertex 0's position, not a vertex.
            self.the_centroid = self.centroid_dfs_order[0]
            return True
        if not hasattr(self, "des_size"):
            self.build_des_size()

        N, ch, size = self.N, self.children, self.des_size
        v = self.root
        # Walk down while some child subtree holds more than half the vertices.
        while True:
            for c in ch[v]:
                if size[c] > N // 2:
                    v = c
                    break
            else:
                self.the_centroid = v
                return True

    def get_the_centroid(self):
        """Return a centroid of the whole tree."""
        if hasattr(self, "centroid_id"):
            return self.centroid_dfs_order[0]
        if not hasattr(self, "the_centroid"):
            self.build_the_centroid()
        return self.the_centroid


    """ tree dp """
    def dp_from_leaf(self, merge, e, add_root, push=lambda obj,data,dst,src:data):
        """
        Bottom-up tree DP; returns sub where sub[v] is the value for v's subtree.

        merge(a, b)             : combine two contributions
        e                       : identity of merge (initial value of every vertex)
        add_root(obj, x, v)     : fold vertex v's own data into merged value x
        push(obj, x, dst, src)  : reshape child src's value x before merging into dst

        Cheat sheet
        subtree size : dp_from_leaf(lambda x,y:x+y, 0, lambda x,y,z:y+1)
        """
        assert self._rooted

        # push reshapes a child's data, which is merged into the parent's slot
        # (initialised to the identity e). Once all children are merged,
        # add_root adds the vertex's own information.

        N, ts, par = self.N, self.ts_order, self.parent
        sub = [e] * N
        for i in range(N-1,-1,-1):
            v = ts[i]
            sub[v] = add_root(self, sub[v], v)
            p = par[v]
            if p != -1:
                sub[p] = merge(sub[p], push(self, sub[v], p, v))
        return sub

    def rerooting_dp(self, merge, e, add_root, push=lambda obj,data,dst,src:data):
        """
        Rerooting (all-directions) tree DP: dp[v] is the value with v as the root.

        Same callbacks as dp_from_leaf; the tree is rooted at 0 first if needed.
        compl[c] holds the value of "everything except c's subtree", seen from
        c's parent.
        """
        if self._rooted == False:
            self.set_root(0)

        sub = self.dp_from_leaf(merge, e, add_root, push)

        N = self.N
        ts, par, ch = self.ts_order, self.parent, self.children

        compl, dp = [e]*N, [e]*N

        for i in range(N):
            v = ts[i]
            p, size = par[v], len(ch[v])
            # Prefix (left) and suffix (right) merges over v's children.
            left, right = [e]*size, [e]*size
            for j in range(size):
                c = ch[v][j]
                left[j] = merge(left[j-1] if j>0 else e, push(self, sub[c], v, c))
            for j in range(size-1,-1,-1):
                c = ch[v][j]
                right[j] = merge(right[j+1] if j<size-1 else e, push(self, sub[c], v, c))

            for j in range(size):
                c = ch[v][j]
                compl[c] = merge(compl[c], left[j-1] if j>0 else e)
                compl[c] = merge(compl[c], right[j+1] if j<size-1 else e)
                if p != -1:
                    compl[c] = merge(compl[c], push(self, compl[v], v, p))
                compl[c] = add_root(self, compl[c], v)

            if p != -1:
                dp[v] = merge(dp[v], push(self, compl[v], v, p))
            dp[v] = merge(dp[v], left[-1] if size else e)
            dp[v] = add_root(self, dp[v], v)

        return dp


    """ dist """
    def build_dist_from_root(self, op = lambda x,y : x+y):
        """dist_from_root[v]: distance from the root (edge weight 1 if unweighted)."""
        assert self._rooted
        if hasattr(self, "dist_from_root"):
            return
        N, ts, ch = self.N, self.ts_order, self.children
        dist = [0]*N
        if self._has_weight:
            wei, en = self.weight, self.encode
        else:
            # Unweighted: every lookup hits index 0 of [1], i.e. weight 1.
            wei, en = [1], lambda a,b:0
        for v in ts:
            for c in ch[v]:
                dist[c] = op(dist[v], wei[en(v, c)])
        self.dist_from_root = dist


    def calc_dist_from_a_node(self, v, op = lambda x,y : x+y):
        """ Distance from v to every vertex (edge weight 1 if unweighted). """
        N, G = self.N, self.G
        dist, que = [None]*N, [v]
        dist[v] = 0
        if self._has_weight:
            wei, en = self.weight, self.encode
        else:
            wei, en = [1], lambda a,b:0
        while que:
            v = que.pop()
            for adj in G[v]:
                if dist[adj] is None:
                    dist[adj] = op(dist[v], wei[en(v, adj)])
                    que.append(adj)
        return dist

    def build_diameter(self):
        """Compute the diameter by two sweeps (farthest from root, then farthest from that).

        NOTE: the two-sweep method is only valid for non-negative edge weights.
        """
        self.build_dist_from_root()
        if hasattr(self, "diameter"):
            return
        dist_r = self.dist_from_root
        v = dist_r.index(max(dist_r))
        dist_v = self.calc_dist_from_a_node(v)
        dia = max(dist_v)
        u = dist_v.index(dia)

        self.diameter, self.end_points_of_diameter = dia, [v, u]

    def get_diameter(self):
        """Return the diameter, computing it on first use."""
        if hasattr(self, "diameter"):
            return self.diameter
        self.build_diameter()
        return self.diameter

main()
# ==================fold line 1800==================