結果

問題 No.898 tri-βutree
ユーザー ch1channnch1channn
提出日時 2023-09-06 06:32:47
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 14,832 bytes
コンパイル時間 867 ms
コンパイル使用メモリ 82,552 KB
実行使用メモリ 121,708 KB
最終ジャッジ日時 2024-06-24 01:13:23
合計ジャッジ時間 21,936 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 639 ms
121,708 KB
testcase_01 AC 69 ms
72,508 KB
testcase_02 WA -
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
testcase_06 WA -
testcase_07 WA -
testcase_08 WA -
testcase_09 WA -
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
if sys.platform =='ios':
	import clipboard
	a=clipboard.get()
	a = a.split('\n')
	text = '\n'.join(a)
	with open('input_file.txt','w') as f:
		f.write(text)
	sys.stdin = open('input_file.txt')
sys.setrecursionlimit(410000000)
stdin = sys.stdin 
def ni():return int(ns())
def na():return list(map(int, stdin.readline().split()))
def ns():return stdin.readline().strip()
def nm():return map(int,input().split())
def na_1():return list(map(lambda x:int(x)*(-1), stdin.readline().split()))
def na_2():return list(map(lambda x:int(x)-1, stdin.readline().split()))
rnage=range
rnge=range
rage = range
rnag=range

from collections import *
from bisect import *
from math import *
from heapq import *
from itertools import *
def popcount(n):
	c=(n&0x5555555555555555)+((n>>1)&0x5555555555555555)
	c=(c&0x3333333333333333)+((c>>2)&0x3333333333333333)
	c=(c&0x0f0f0f0f0f0f0f0f)+((c>>4)&0x0f0f0f0f0f0f0f0f)
	c=(c&0x00ff00ff00ff00ff)+((c>>8)&0x00ff00ff00ff00ff)
	c=(c&0x0000ffff0000ffff)+((c>>16)&0x0000ffff0000ffff)
	c=(c&0x00000000ffffffff)+((c>>32)&0x00000000ffffffff)
	return c
"""
mod = 10**9+7
M= (10**5)*3+1 
fac= [1]*M
ninv= [1]*M
finv= [1]*M
for i in range(2,M):
  fac[i] = fac[i-1]*i%mod
  ninv[i] = (-(mod//i)*ninv[mod%i])%mod
  finv[i] = finv[i-1]*ninv[i]%mod
  
def binom(n,k):
  if n<0 or k<0:
    return 0
  if k>n:
    return 0
  return (fac[n]*finv[k]%mod)*finv[n-k]%mod


def nHk(n, k):
	return binom(n + k - 1, k)
"""
def nCk(n, k):
	k = min(k, n - k)
	ret = 1
	for i in range(n, n - k, -1): ret *= i
	for i in range(2, k + 1): ret //= i
	return ret

def nHk(n, k):
	return nCk(n + k - 1, k)

def nC2(x):
	return x*(x-1)//2
def is_prime(n):
    if n == 1: return False

    for k in range(2, int(math.sqrt(n)) + 1):
        if n % k == 0:
            return False

    return True

class UnionFind():
    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n
        self.group = n
 
    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]
 
    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)
 
        if x == y:
            return
        self.group -= 1
        if self.parents[x] > self.parents[y]:
            x, y = y, x
 
        self.parents[x] += self.parents[y]
        self.parents[y] = x
 
    def size(self, x):
        return -self.parents[self.find(x)]
 
    def same(self, x, y):
        return self.find(x) == self.find(y)
 
    def members(self, x):
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]
 
    def roots(self):
        return [i for i, x in enumerate(self.parents) if x < 0]
 
    def group_count(self):
        return self.group
 
    def all_group_members(self):
        dic = {r:[] for r in self.roots()}
        for i in range(self.n):
            dic[self.find(i)].append(i)
        return dic
 
    def __str__(self):
        return '\n'.join('{}: {}'.format(r, self.members(r)) for r in self.roots())
class SOE:
  def __init__(self,m):
    self.sieve=[-1]*(m+1)
    self.prime=[]
    for i in range(2,m+1):
      if self.sieve[i]==-1:
        self.prime.append(i)
        self.sieve[i]=i
        j=2*i
        while j<=m:
          self.sieve[j]=i
          j+=i
  
  def primes(self):
    # get primes
    return self.prime
  
  def fact(self,n):
    # prime factorization
    d=[]
    while n!=1:
      p=self.sieve[n]
      d.append(p)
      while n%p==0:
        n//=p
    return d
  
def div(n):
	lower,upper = [],[]
	i = 1
	while i*i <= n:
		if n%i == 0:
			lower.append(i)
			if i!=n//i:
				upper.append(n//i)
		i += 1
	return lower+upper[::-1]

def factorize(x):
	yaku = []
	for i in range(2,int(x**0.5)+1):
		if x%i == 0:
			while x%i == 0:
				x //= i
				yaku.append(i)
	if x != 1:
		yaku.append(x)
	return yaku
#[2,3]この形
		
def fact(n):
  res=n
  a=[]
  i=2
  while i*i<=res:
    if res%i==0:
      cnt=0
      while res%i==0:
        cnt+=1
        res//=i
      a.append((i,cnt))
    i+=1
  if res!=1:
    a.append((res,1))
  return a
"""
[(2, 1), (3, 1)]この形
"""		

def nc2(x):
	return (x*(x-1)//2)
	
import random
class RollingHash:    
    mask30 = (1 << 30) - 1
    mask31 = (1 << 31) - 1
    MOD = (1 << 61) - 1
    Base = None
    pw = [1]
    
    def __init__(self, S):
        if RollingHash.Base is None:
            RollingHash.Base = random.randrange(129, 1 << 30)
        for i in range(len(RollingHash.pw), len(S) + 1):
            RollingHash.pw.append(RollingHash.CalcMod(RollingHash.Mul(RollingHash.pw[i - 1], self.__class__.Base)))
        
        self.hash = [0] * (len(S) + 1)
        for i, s in enumerate(S, 1):
            self.hash[i] = RollingHash.CalcMod(RollingHash.Mul(self.hash[i - 1], RollingHash.Base) + ord(s))
 
    def get(self, l, r):
        return RollingHash.CalcMod(self.hash[r] - RollingHash.Mul(self.hash[l], RollingHash.pw[r - l]))
 
    def Mul(l, r):
        lu = l >> 31
        ld = l & RollingHash.mask31
        ru = r >> 31
        rd = r & RollingHash.mask31
        middlebit = ld * ru + lu * rd
        return ((lu * ru) << 1) + ld * rd + \
            ((middlebit & RollingHash.mask30) << 31) + (middlebit >> 30)
 
    def CalcMod(val):
        if val < 0:
            val %= RollingHash.MOD
        val = (val & RollingHash.MOD) + (val >> 61)
        if val > RollingHash.MOD:
            val -= RollingHash.MOD
        return val
        
# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
import math
from bisect import bisect_left, bisect_right, insort
from typing import Generic, Iterable, Iterator, TypeVar, Optional, List
T = TypeVar('T')

class SortedMultiset(Generic[T]):
    BUCKET_RATIO = 50
    REBUILD_RATIO = 170

    def _build(self, a=None) -> None:
        "Evenly divide `a` into buckets."
        if a is None: a = list(self)
        size = self.size = len(a)
        bucket_size = int(math.ceil(math.sqrt(size / self.BUCKET_RATIO)))
        self.a = [a[size * i // bucket_size : size * (i + 1) // bucket_size] for i in range(bucket_size)]
    
    def __init__(self, a: Iterable[T] = []) -> None:
        "Make a new SortedMultiset from iterable. / O(N) if sorted / O(N log N)"
        a = list(a)
        if not all(a[i] <= a[i + 1] for i in range(len(a) - 1)):
            a = sorted(a)
        self._build(a)

    def __iter__(self) -> Iterator[T]:
        for i in self.a:
            for j in i: yield j

    def __reversed__(self) -> Iterator[T]:
        for i in reversed(self.a):
            for j in reversed(i): yield j
    
    def __len__(self) -> int:
        return self.size
    
    def __repr__(self) -> str:
        return "SortedMultiset" + str(self.a)
    
    def __str__(self) -> str:
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _find_bucket(self, x: T) -> List[T]:
        "Find the bucket which should contain x. self must not be empty."
        for a in self.a:
            if x <= a[-1]: return a
        return a

    def __contains__(self, x: T) -> bool:
        if self.size == 0: return False
        a = self._find_bucket(x)
        i = bisect_left(a, x)
        return i != len(a) and a[i] == x

    def count(self, x: T) -> int:
        "Count the number of x."
        return self.index_right(x) - self.index(x)

    def add(self, x: T) -> None:
        "Add an element. / O(√N)"
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return
        a = self._find_bucket(x)
        insort(a, x)
        self.size += 1
        if len(a) > len(self.a) * self.REBUILD_RATIO:
            self._build()

    def discard(self, x: T) -> bool:
        "Remove an element and return True if removed. / O(√N)"
        if self.size == 0: return False
        a = self._find_bucket(x)
        i = bisect_left(a, x)
        if i == len(a) or a[i] != x: return False
        a.pop(i)
        self.size -= 1
        if len(a) == 0: self._build()
        return True

    def lt(self, x: T) -> Optional[T]:
        "Find the largest element < x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]

    def le(self, x: T) -> Optional[T]:
        "Find the largest element <= x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]

    def gt(self, x: T) -> Optional[T]:
        "Find the smallest element > x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]

    def ge(self, x: T) -> Optional[T]:
        "Find the smallest element >= x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
    
    def __getitem__(self, x: int) -> T:
        "Return the x-th element, or IndexError if it doesn't exist."
        if x < 0: x += self.size
        if x < 0: raise IndexError
        for a in self.a:
            if x < len(a): return a[x]
            x -= len(a)
        raise IndexError

    def index(self, x: T) -> int:
        "Count the number of elements < x."
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        "Count the number of elements <= x."
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans

oo = 1<<60
ans = 0
cnt = 0
res = oo
def dijkstra(n, s, G):
	hq = [(0, s)]
	cost = [float('inf')] * n
	cost[s] = 0
	
	while hq:
		if c > cost[v]:
			continue
		for nv,d in G[v]:
			tmp = d + cost[v]
			if tmp < cost[nv]:
				cost[nv] = tmp
				heappush(hq, (tmp,nv))
	return cost



class Topological_Sort:
    def __init__(self, N: int):
        """ N 頂点からなる空グラフを用意する.
 
        N: int
        """
        self.N=N
        self.arc=[[] for _ in  range(N)]
        self.rev=[[] for _ in range(N)]
 
    def add_arc(self, source: int, target: int):
        """ 有向辺 source → taeget を追加する.
 
        """
        self.arc[source].append(target)
        self.rev[target].append(source)
 
    def sort(self):
        """ トポロジカルソートを求める.
 
        [Ouput]
        存在する → トポロジカルソート
        存在しない → None
        """
 
        in_deg=[len(self.rev[x]) for x in range(self.N)]
        Q=[x for x in range(self.N) if in_deg[x]==0]
 
        S=[]
        while Q:
            u=Q.pop()
            S.append(u)
 
            for v in self.arc[u]:
                in_deg[v]-=1
                if in_deg[v]==0:
                    Q.append(v)
 
        return S if len(S)==self.N else None
 
    def is_DAG(self):
        """ DAG かどうかを判定する.
        """
        return self.sort()!=None
"""
class LowestCommonAncestor:
    def __init__(self,n):
        self._n=n;n=0
        while 2**(n/10)<self._n:n+=1
        self._logn=int(n/10+2)  #mathモジュールなしで構築
        self._depth=    [ 0 for _ in [0]*self._n]
        self._distance= [ 0 for _ in [0]*self._n]
        self._ancestor=[[-1 for _ in [0]*self._n] for k in [0]*self._logn]
        self._edge=     [[] for _ in [0]*self._n]

    def add_edge(self,u,v,w=1):  #頂点u,v間に重みwの辺を追加する
        self._edge[u].append((v,w))
        self._edge[v].append((u,w))

    def build(self,root=0):  #rootを指定し、その他の頂点に祖先情報を書き込む
        stack=[root]
        while stack:
            now=stack.pop()
            for nxt,w in self._edge[now]:
                if self._ancestor[0][nxt]!=now and self._ancestor[0][now]!=nxt:
                    self._ancestor[0][nxt]=now
                    self._depth[nxt]=   self._depth[now]   +1
                    self._distance[nxt]=self._distance[now]+w
                    stack.append(nxt)
        for k in range(1,self._logn):
            for i in range(self._n):
                if self._ancestor[k-1][i]==-1:
                     self._ancestor[k][i]=-1
                else:self._ancestor[k][i]=self._ancestor[k-1][self._ancestor[k-1][i]]

    def LCA(self,u,v):
        if self._depth[u]>self._depth[v]:u,v=v,u  #uが浅く、vが深い状態に
        for k in range(self._logn-1,-1,-1):  #vとuを同じ深度に
            if ((self._depth[v]-self._depth[u])>>k)&1:v=self._ancestor[k][v]
        if u==v:return u
        for k in range(self._logn-1,-1,-1):  #ギリギリ一致する直前まで祖先を辿る
            if self._ancestor[k][u]!=self._ancestor[k][v]:
                u,v=self._ancestor[k][u],self._ancestor[k][v]
        return self._ancestor[0][u]

    def distance(self,u,v):
        return self._distance[u]+self._distance[u]-2*self._distance[self.LCA(u,v)]
"""
import math

class LCA:
  def __init__(self,n):
    self._n=n
    self._logn=int(math.log2(self._n)+2)
    self._depth=[0]*self._n
    self._distance=[0]*self._n
    self._ancestor=[[-1]*self._n for k in range(self._logn)]
    self._edges=[[] for i in range(self._n)]
    
  def add_edge(self,u,v,w=1):
    self._edges[u].append((v,w))
    self._edges[v].append((u,w))
    
  def build(self,root=0):
    stack=[root]
    while len(stack):
      cur = stack.pop()
      for nxt,w in self._edges[cur]:
        if self._ancestor[0][nxt]!=cur and self._ancestor[0][cur]!=nxt:
          self._ancestor[0][nxt]=cur
          self._depth[nxt]=self._depth[cur]+1
          self._distance[nxt]=self._distance[cur]+w
          stack.append(nxt)
    
    for k in range(1,self._logn):
      for i in range(self._n):
        if self._ancestor[k-1][i]==-1:
          self._ancestor[k][i]=-1
        else:
          self._ancestor[k][i]=self._ancestor[k-1][self._ancestor[k-1][i]]
        
  def lca(self,u,v):
    if self._depth[u]>self._depth[v]:
      u,v=v,u
    
    for k in range(self._logn-1,-1,-1):
      if ((self._depth[v]-self._depth[u])>>k)&1:
        v=self._ancestor[k][v]
      
    if u==v:
      return u
    
    for k in range(self._logn-1,-1,-1):
      if self._ancestor[k][u]!=self._ancestor[k][v]:
        u=self._ancestor[k][u]
        v=self._ancestor[k][v]
    return self._ancestor[0][u]
  
  def distance(self,u,v):
    return self._distance[u]+self._distance[v]-2*self._distance[self.lca(u,v)]
  
						

N = ni()
tree = LCA(N)

for i in range(N-1):
	a,b,c = nm()
	tree.add_edge(a,b,c)

tree.build()
q = ni()
for _ in range(q):
	x,y,z = nm()
	d1 = tree.distance(x,y)
	d2 = tree.distance(x,z)
	Z = tree.lca(y,z)
	#print(Z)
	d3 = tree.distance(Z,x)
	print(d2+d1-d3)
0