結果

問題 No.196 典型DP (1)
ユーザー Navier_BoltzmannNavier_Boltzmann
提出日時 2023-08-04 00:50:18
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 228 ms / 2,000 ms
コード長 3,850 bytes
コンパイル時間 166 ms
コンパイル使用メモリ 82,688 KB
実行使用メモリ 109,516 KB
最終ジャッジ日時 2024-10-13 21:53:30
合計ジャッジ時間 7,448 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 41
権限があれば一括ダウンロードができます

ソースコード

diff #

from collections import *
from itertools import *
from functools import *
from heapq import *
import sys,math
input = sys.stdin.readline
N,K = map(int,input().split())
e = [[] for _ in range(N)]
for _ in range(N-1):
    a,b = map(int,input().split())
    e[a].append(b)
    e[b].append(a)

class HLD():
    
    ### HL分解をしてIDを振りなおしたものに対して、パスに含まれる区間を返す
    ### SegTreeにのせる配列はIDを並び替えたもの
    
    
    def __init__(self,e,root=0):
        
        
        self.N = len(e)
        self.e = e
        par = [-1]*N
        sub = [-1]*N
        self.root = root
        dist = [-1]*N
        v = deque()
        dist[root]=0
        v.append(root)
        while v:
            x = v.popleft()
            for ix in e[x]:
                if dist[ix] !=-1:
                    continue
                dist[ix] = dist[x] + 1
                v.append(ix)
        
        H = [(-dist[i],i) for i in range(N)]
        H.sort()
        for h,i in H:
            tmp = 1
            for ix in e[i]:
                if sub[ix] == -1:
                    par[i]= ix
                else:
                    tmp += sub[ix]
            sub[i] = tmp
        
        
        self.ID = [-1]*N
        self.ID[self.root]=0
        self.HEAD = [-1]*N
        head = [-1]*N
        self.PAR = [-1]*N
        visited = [False]*N
        self.HEAD[0]=0
        head[self.root]=0
        depth = [-1]*N
        depth[self.root]=0
        self.DEPTH = [-1]*N
        self.DEPTH[0]=0
        cnt = 0
        v = deque([self.root])
        self.SUB = [0]*N
        self.SUB[0] = N
        while v:
            x = v.popleft()
            visited[x]=True
            self.ID[x]=cnt
            cnt += 1
            n = len(self.e[x])
            tmp = [(sub[ix],ix) for ix in self.e[x]]
            tmp.sort()
            flg = 0
            if x==self.root:
                flg -= 1
            for _,ix in tmp:
                flg += 1
                if visited[ix]:
                    continue
                v.appendleft(ix)
                if flg==n-1:
                    head[ix] = head[x]
                    depth[ix] = depth[x]
                else:
                    head[ix] = ix
                    depth[ix] = depth[x]+1
        
        for i in range(self.N):
            self.PAR[self.ID[i]] = self.ID[par[i]]
            self.HEAD[self.ID[i]] = self.ID[head[i]]
            self.DEPTH[self.ID[i]] = depth[i]
            self.SUB[self.ID[i]] = sub[i]
        
    def path_query(self,l,r):
        L = self.ID[l]
        R = self.ID[r]
        res = []
        if self.DEPTH[L]<self.DEPTH[R]:
            L,R = R,L
        
        while self.DEPTH[L] != self.DEPTH[R]:
            tmp = (self.HEAD[L],L+1)
            res.append(tmp)
            L = self.PAR[self.HEAD[L]]
        
        while self.HEAD[L] != self.HEAD[R]:
            tmp = (self.HEAD[L],L+1)
            res.append(tmp)
            L = self.PAR[self.HEAD[L]]            
            tmp = (self.HEAD[R],R+1)
            res.append(tmp)
            R = self.PAR[self.HEAD[R]]        
        
        if L>R:
            L,R = R,L
            
        tmp = (L,R+1)
        res.append(tmp)
        
        return res
        
    def sub_query(self,k):
        
        K = self.ID[k]
        
        return (K,K+self.SUB[K])
        
        
mod = 10**9 + 7
hld = HLD(e)
ID = hld.ID[:]
inv = {v:i for i,v in enumerate(ID)}
dp = [[0]*(N+1) for _ in range(N+1)]
dp[0][0]=1
for i in range(N):
    
    sub = hld.SUB[i]
    dp[i][0]=1
    # print(sub)
    for j in range(N-sub+1):
        dp[i+sub][j+sub] += dp[i][j]
        dp[i+sub][j+sub] %= mod
    for j in range(N+1):
        dp[i+1][j] += dp[i][j]
        dp[i+1][j] %= mod
# print(sum(dp[i][K] for i in range(N+1)))
print(dp[-1][K])
0