結果

問題 No.439 チワワのなる木
ユーザー Navier_BoltzmannNavier_Boltzmann
提出日時 2024-11-03 09:50:38
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 18,173 bytes
コンパイル時間 333 ms
コンパイル使用メモリ 82,480 KB
実行使用メモリ 142,480 KB
最終ジャッジ日時 2024-11-03 09:50:51
合計ジャッジ時間 12,179 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 57 ms
66,472 KB
testcase_01 AC 56 ms
67,052 KB
testcase_02 AC 60 ms
66,480 KB
testcase_03 AC 59 ms
67,264 KB
testcase_04 AC 56 ms
67,060 KB
testcase_05 AC 57 ms
66,504 KB
testcase_06 AC 58 ms
67,752 KB
testcase_07 AC 58 ms
67,472 KB
testcase_08 WA -
testcase_09 WA -
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 WA -
testcase_24 AC 1,159 ms
128,548 KB
testcase_25 AC 1,128 ms
140,448 KB
testcase_26 WA -
testcase_27 AC 440 ms
127,116 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

# import pypyjit
# pypyjit.set_param('max_unroll_recursion=-1')

from collections import *
from functools import *
from heapq import *
from itertools import *
import sys, math,random
# input = sys.stdin.buffer.readline
# sys.setrecursionlimit(10**6)

def cle(a, D):
    """
    Counts the number of elements in D that are less than or equal to a.

    Parameters:
    a (int): The value to compare against.
    D (list): A sorted list of integers.

    Returns:
    int: The count of elements in D that are less than or equal to a.
    """
    y = len(D) - 1
    x = 0
    if D[x] > a:
        return 0

    if D[y] <= a:
        return y + 1

    while y - x > 1:
        mid = (y + x) // 2
        if D[mid] <= a:
            x = mid
        else:
            y = mid
    return y
class cs_2d:
    """
    2D cumulative sum class.
    """

    def __init__(self, x):
        """
        Initializes the 2D cumulative sum array.

        Parameters:
        x (list of list of int): A 2D list of integers.
        """
        n = len(x)
        m = len(x[0])
        self.n = n
        self.m = m

        tmp = [0] * ((n + 1) * (m + 1))
        for i in range(n):
            for j in range(m):
                tmp[m * (i + 1) + j + 1] = (
                    tmp[m * (i + 1) + j] + tmp[m * i + j + 1] - tmp[m * i + j] + x[i][j]
                )

        self.S = tmp

    def query(self, ix, jx, iy, jy):
        """
        Queries the sum of the submatrix from (ix, iy) to (jx, jy).

        Parameters:
        ix (int): Starting row index.
        jx (int): Ending row index.
        iy (int): Starting column index.
        jy (int): Ending column index.

        Returns:
        int: The sum of the submatrix.
        """
        return (
            self.S[self.m * jx + jy]
            - self.S[self.m * jx + iy]
            - self.S[self.m * ix + jy]
            + self.S[self.m * ix + iy]
        )
class prime_factorize:
    """
    Class for prime factorization and related operations.
    """

    def __init__(self, M=10**6):
        """
        Initializes the sieve for prime factorization.

        Parameters:
        M (int): The maximum number to factorize.
        """
        self.sieve = [-1] * (M + 1)
        self.sieve[1] = 1
        self.p = [False] * (M + 1)
        self.mu = [1] * (M + 1)

        for i in range(2, M + 1):
            if self.sieve[i] == -1:
                self.p[i] = True

                i2 = i**2
                for j in range(i2, M + 1, i2):
                    self.mu[j] = 0

                for j in range(i, M + 1, i):
                    self.sieve[j] = i
                    self.mu[j] *= -1

    def factors(self, x):
        """
        Returns the prime factors of x.

        Parameters:
        x (int): The number to factorize.

        Returns:
        list: A list of prime factors of x.
        """
        tmp = []
        while self.sieve[x] != x:
            tmp.append(self.sieve[x])
            x //= self.sieve[x]
        tmp.append(self.sieve[x])
        return tmp

    def divisors(self, x):
        """
        Returns all divisors of x.

        Parameters:
        x (int): The number to find divisors for.

        Returns:
        list: A sorted list of all divisors of x.
        """
        C = Counter(self.factors(x))
        tmp = []
        for p in product(*[[pow(k, i) for i in range(v + 1)] for k, v in C.items()]):
            res = 1
            for pp in p:
                res *= pp
            tmp.append(res)
        tmp.sort()
        return tmp

    def is_prime(self, x):
        """
        Checks if x is a prime number.

        Parameters:
        x (int): The number to check.

        Returns:
        bool: True if x is prime, False otherwise.
        """
        return self.p[x]

    def mobius(self, x):
        """
        Returns the Möbius function value of x.

        Parameters:
        x (int): The number to find the Möbius function value for.

        Returns:
        int: The Möbius function value of x.
        """
        return self.mu[x]
class combination:
    """
    Class for computing combinations (nCr) modulo p.
    """

    def __init__(self, N, p):
        """
        Initializes the combination class.

        Parameters:
        N (int): The maximum value of n.
        p (int): The modulus.
        """
        self.fact = [1, 1]  # fact[n] = (n! mod p)
        self.factinv = [1, 1]  # factinv[n] = ((n!)^(-1) mod p)
        self.inv = [0, 1]  # factinv calculation
        self.p = p

        for i in range(2, N + 1):
            self.fact.append((self.fact[-1] * i) % p)
            self.inv.append((-self.inv[p % i] * (p // i)) % p)
            self.factinv.append((self.factinv[-1] * self.inv[-1]) % p)

    def cmb(self, n, r):
        """
        Computes the combination (nCr) modulo p.

        Parameters:
        n (int): The total number of items.
        r (int): The number of items to choose.

        Returns:
        int: The value of nCr modulo p.
        """
        if (r < 0) or (n < r):
            return 0
        r = min(r, n - r)
        return self.fact[n] * self.factinv[r] * self.factinv[n - r] % self.p
def md(n):
    """
    Returns all divisors of n.

    Parameters:
    n (int): The number to find divisors for.

    Returns:
    list: A sorted list of all divisors of n.
    """
    lower_divisors, upper_divisors = [], []
    i = 1
    while i * i <= n:
        if n % i == 0:
            lower_divisors.append(i)
            if i != n // i:
                upper_divisors.append(n // i)
        i += 1
    return lower_divisors + upper_divisors[::-1]
class DSU:
    """
    Disjoint Set Union (Union-Find) class.
    """

    def __init__(self, n):
        """
        Initializes the DSU.

        Parameters:
        n (int): The number of elements.
        """
        self._n = n
        self.parent_or_size = [-1] * n
        self.member = [[i] for i in range(n)]
        self._max = [i for i in range(n)]
        self._min = [i for i in range(n)]

    def merge(self, a, b):
        """
        Merges the sets containing a and b.

        Parameters:
        a (int): An element in the first set.
        b (int): An element in the second set.

        Returns:
        int: The leader of the merged set.
        """
        assert 0 <= a < self._n
        assert 0 <= b < self._n
        x, y = self.leader(a), self.leader(b)
        if x == y:
            return x
        if -self.parent_or_size[x] < -self.parent_or_size[y]:
            x, y = y, x
        self.parent_or_size[x] += self.parent_or_size[y]
        self._max[x] = max(self._max[x],self._max[y])
        self._min[x] = min(self._min[x],self._min[y])
        for tmp in self.member[y]:
            self.member[x].append(tmp)
        self.parent_or_size[y] = x
        return x
    def get_max(self,x):
        return self._max[self.leader(x)]
    def get_min(self,x):
        return self._min[self.leader(x)]

    def members(self, a):
        """
        Returns the members of the set containing a.

        Parameters:
        a (int): An element in the set.

        Returns:
        list: A list of members in the set containing a.
        """
        return self.member[self.leader(a)]

    def same(self, a, b):
        """
        Checks if a and b are in the same set.

        Parameters:
        a (int): An element in the first set.
        b (int): An element in the second set.

        Returns:
        bool: True if a and b are in the same set, False otherwise.
        """
        assert 0 <= a < self._n
        assert 0 <= b < self._n
        return self.leader(a) == self.leader(b)

    def leader(self, a):
        """
        Finds the leader of the set containing a.

        Parameters:
        a (int): An element in the set.

        Returns:
        int: The leader of the set containing a.
        """
        assert 0 <= a < self._n
        if self.parent_or_size[a] < 0:
            return a
        self.parent_or_size[a] = self.leader(self.parent_or_size[a])
        return self.parent_or_size[a]

    def size(self, a):
        """
        Returns the size of the set containing a.

        Parameters:
        a (int): An element in the set.

        Returns:
        int: The size of the set containing a.
        """
        assert 0 <= a < self._n
        return -self.parent_or_size[self.leader(a)]

    def groups(self):
        """
        Returns all sets as a list of lists.

        Returns:
        list: A list of lists, where each list contains the members of a set.
        """
        leader_buf = [self.leader(i) for i in range(self._n)]
        result = [[] for _ in range(self._n)]
        for i in range(self._n):
            result[leader_buf[i]].append(i)
        return [r for r in result if r != []]
class SegTree:
    """
    Segment Tree class.
    """

    def __init__(self, init_val, segfunc, ide_ele):
        """
        Initializes the Segment Tree.

        Parameters:
        init_val (list): The initial values for the leaves of the tree.
        segfunc (function): The function to use for segment operations.
        ide_ele (any): The identity element for the segment function.
        """
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        self.num = 1 << (n - 1).bit_length()
        self.tree = [ide_ele] * 2 * self.num
        # Set the initial values to the leaves
        for i in range(n):
            self.tree[self.num + i] = init_val[i]
        # Build the tree
        for i in range(self.num - 1, 0, -1):
            self.tree[i] = segfunc(self.tree[2 * i], self.tree[2 * i + 1])

    def update(self, k, x):
        """
        Updates the k-th value to x.

        Parameters:
        k (int): The index to update (0-indexed).
        x (any): The new value.
        """
        k += self.num
        self.tree[k] = x
        while k > 1:
            tk = k >> 1
            self.tree[tk] = self.segfunc(self.tree[tk << 1], self.tree[(tk << 1) + 1])
            k >>= 1

    def get(self, x):
        return self.tree[x + self.num]

    def query(self, l, r):
        """
        Queries the segment function result for the range [l, r).

        Parameters:
        l (int): The start index (0-indexed).
        r (int): The end index (0-indexed).

        Returns:
        any: The result of the segment function for the range [l, r).
        """
        res_l = self.ide_ele
        res_r = self.ide_ele

        l += self.num
        r += self.num
        while l < r:
            if l & 1:
                res_l = self.segfunc(res_l, self.tree[l])
                l += 1
            if r & 1:
                res_r = self.segfunc(self.tree[r - 1], res_r)
            l >>= 1
            r >>= 1
        res = self.segfunc(res_l, res_r)
        return res
class RSQandRAQ():
    """区間加算、区間取得クエリをそれぞれO(logN)で答える
    add: 区間[l, r)にvalを加える
    query: 区間[l, r)の和を求める
    l, rは0-indexed
    """

    def __init__(self, n, mod=None):
        self.n = n
        self.bit0 = [0] * (n + 1)
        self.bit1 = [0] * (n + 1)
        self.mod = mod
    def _add(self, bit, i, val):
        i = i + 1
        while i <= self.n:
            if self.mod is None:
                bit[i] += val
            else:
                bit[i] = (bit[i]+val)%self.mod
            i += i & -i

    def _get(self, bit, i):
        s = 0
        while i > 0:
            if self.mod is None:
                s += bit[i]
            else:
                s = (s + bit[i])%self.mod
            i-= i & -i
        return s

    def add(self, l, r, val):
        """区間[l, r)にvalを加える"""
        self._add(self.bit0, l, -val * l)
        self._add(self.bit0, r,  val * r)
        self._add(self.bit1, l,  val)
        self._add(self.bit1, r, -val)

    def query(self, l, r):
        """区間[l, r)の和を求める"""
        _res = (self._get(self.bit0, r) + r * self._get(self.bit1, r)
            - self._get(self.bit0, l) - l * self._get(self.bit1, l) )
        if self.mod is None:
            return _res
        else:
            return _res%self.mod
class Dinic:
    def __init__(self, n):
        self.n = n
        self.links = [[] for _ in range(n)]
        self.depth = None
        self.progress = None
 
    def add_link(self, _from, to, cap):
        self.links[_from].append([cap, to, len(self.links[to])])
        self.links[to].append([0, _from, len(self.links[_from]) - 1])
 
    def bfs(self, s):
        depth = [-1] * self.n
        depth[s] = 0
        q = deque([s])
        while q:
            v = q.popleft()
            for cap, to, rev in self.links[v]:
                if cap > 0 and depth[to] < 0:
                    depth[to] = depth[v] + 1
                    q.append(to)
        self.depth = depth
 
    def dfs(self, v, t, flow):
        if v == t:
            return flow
        links_v = self.links[v]
        for i in range(self.progress[v], len(links_v)):
            self.progress[v] = i
            cap, to, rev = link = links_v[i]
            if cap == 0 or self.depth[v] >= self.depth[to]:
                continue
            d = self.dfs(to, t, min(flow, cap))
            if d == 0:
                continue
            link[0] -= d
            self.links[to][rev][0] += d
            return d
        return 0
 
    def max_flow(self, s, t):
        flow = 0
        while True:
            self.bfs(s)
            if self.depth[t] < 0:
                return flow
            self.progress = [0] * self.n
            current_flow = self.dfs(s, t, float('inf'))
            while current_flow > 0:
                flow += current_flow
                current_flow = self.dfs(s, t, float('inf'))

class HLD():
    
    ### HL分解をしてIDを振りなおしたものに対して、パスに含まれる区間を返す
    ### SegTreeにのせる配列はIDを並び替えたもの
    
    def __init__(self,e,root=0):
                
        self.N = len(e)
        self.e = e
        par = [-1]*self.N
        sub = [-1]*self.N
        self.root = root
        dist = [-1]*self.N
        v = deque()
        dist[root]=0
        v.append(root)
        while v:
            x = v.popleft()
            for ix in e[x]:
                if dist[ix] !=-1:
                    continue
                dist[ix] = dist[x] + 1
                v.append(ix)
        
        H = [(-dist[i],i) for i in range(self.N)]
        H.sort()
        for h,i in H:
            tmp = 1
            for ix in e[i]:
                if sub[ix] == -1:
                    par[i]= ix
                else:
                    tmp += sub[ix]
            sub[i] = tmp
        
        
        self.ID = [-1]*self.N
        self.ID[self.root]=0
        self.HEAD = [-1]*self.N
        head = [-1]*self.N
        self.PAR = [-1]*self.N
        visited = [False]*self.N
        self.HEAD[0]=0
        head[self.root]=0
        depth = [-1]*self.N
        depth[self.root]=0
        self.DEPTH = [-1]*self.N
        self.DEPTH[0]=0
        cnt = 0
        v = deque([self.root])
        self.SUB = [0]*self.N
        self.SUB[0] = self.N
        while v:
            x = v.popleft()
            visited[x]=True
            self.ID[x]=cnt
            cnt += 1
            n = len(self.e[x])
            tmp = [(sub[ix],ix) for ix in self.e[x]]
            tmp.sort()
            flg = 0
            if x==self.root:
                flg -= 1
            for _,ix in tmp:
                flg += 1
                if visited[ix]:
                    continue
                v.appendleft(ix)
                if flg==n-1:
                    head[ix] = head[x]
                    depth[ix] = depth[x]
                else:
                    head[ix] = ix
                    depth[ix] = depth[x]+1
        
        for i in range(self.N):
            self.PAR[self.ID[i]] = self.ID[par[i]]
            self.HEAD[self.ID[i]] = self.ID[head[i]]
            self.DEPTH[self.ID[i]] = depth[i]
            self.SUB[self.ID[i]] = sub[i]
        
    def path_query(self,l,r):
        L = self.ID[l]
        R = self.ID[r]
        res = []
        if self.DEPTH[L]<self.DEPTH[R]:
            L,R = R,L
        
        while self.DEPTH[L] != self.DEPTH[R]:
            tmp = (self.HEAD[L],L+1)
            res.append(tmp)
            L = self.PAR[self.HEAD[L]]
        
        while self.HEAD[L] != self.HEAD[R]:
            tmp = (self.HEAD[L],L+1)
            res.append(tmp)
            L = self.PAR[self.HEAD[L]]            
            tmp = (self.HEAD[R],R+1)
            res.append(tmp)
            R = self.PAR[self.HEAD[R]]        
        
        if L>R:
            L,R = R,L
            
        tmp = (L,R+1)
        res.append(tmp)
        
        return res
        
    def sub_query(self,k):
        
        K = self.ID[k]
        
        return (K,K+self.SUB[K])
    
class HLD_SegTree:


    def __init__(self,e,init_val,segfunc,ide_ele,root=0):

        self.hld = HLD(e,root = root)
        self.ID = self.hld.ID[:]
        self.N = len(e)
        A = [0]*self.N
        for i,idx in enumerate(self.ID):
            A[idx] = init_val[i]
        self.seg = SegTree(A,segfunc,ide_ele)
        self.segfunc = segfunc
        self.ide_ele = ide_ele

    def path_query(self,l,r):
        res = self.ide_ele
        for _l,_r in self.hld.path_query(l,r):
            res = self.segfunc(res,self.seg.query(_l,_r))
        return res
    def sub_query(self,x):
        
        _l,_r = self.hld.sub_query(x)
        return self.seg.query(_l,_r)


N = int(input())
S = list(input())
e = [[] for _ in range(N)]
for _ in range(N-1):
    a,b = map(int,input().split())
    a -= 1
    b -= 1
    e[a].append(b)
    e[b].append(a)
C = [int(s=='c') for s in S]
hld = HLD_SegTree(e,C,lambda x,y:x+y,0,0)
hld2 = HLD_SegTree(e,[1-c for c in C],lambda x,y:x+y,0,0)
M = sum(C)
ans = 0

for i in range(N):
    if S[i]=='c':
        continue
    cc0 = hld.sub_query(i)
    mm0 = hld2.sub_query(i) - 1
    cc1 = M - cc0
    mm1 = N - 1 - cc0 - cc1 - mm0
    ans = (ans + cc1*mm0 + cc0*mm1)
print(ans)
0