結果

問題 No.1227 I hate ThREE
ユーザー uni_pythonuni_python
提出日時 2020-09-13 14:30:01
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 262 ms / 2,000 ms
コード長 2,452 bytes
コンパイル時間 186 ms
コンパイル使用メモリ 82,444 KB
実行使用メモリ 124,848 KB
最終ジャッジ日時 2024-06-12 19:04:59
合計ジャッジ時間 6,973 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 33
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input=sys.stdin.readline
def I(): return int(input())
def MI(): return map(int, input().split())
def LI(): return list(map(int, input().split()))


"""

根からスタートして+3 or -3 して根まで行きたい,その中で制約を満たさなきゃダメ(1<=p<=C)
N<=1000なので,全部+3とかしても高々+3000しか行かない.
ひとまず制約無視,根を0に固定して(2^(N-1)通り)
最大値,最小値をそれぞれで考えてダメなものを外す?...微妙

木dpとかしたい.O(NC)だけど,Cが大きく根の値が中途半端な時は全通り(=2^(N-1))可能か

根が3N以下ならC=6Nとみなして計算できる

"""
def main():
    mod=10**9+7
    N,C=MI()
    adj=[[]for _ in range(N)]
    for _ in range(N-1):
        a,b=MI()
        a-=1
        b-=1
        adj[a].append(b)
        adj[b].append(a)
        
    ch=[[]for _ in range(N)]
    P=[-1]*N
    import queue
    q=queue.Queue()
    q.put((0,-1))
    
    L=[] # bfsで見る順
    while not q.empty():
        v,p=q.get()
        L.append(v)
        for nv in adj[v]:
            if nv!=p:
                ch[v].append(nv)
                P[nv]=v
                q.put((nv,v))

    L=L[::-1] # dpで見る順
    
    def calc_X(X,flag):
        # 制約が1~X
        
        dp=[[1]*X for _ in range(N)]
        # dp[i][j]はi番目の頂点の値がjであるとき,此の頂点を根とした部分木の通り数
        for v in L:
            p=P[v]
            if p==-1:
                continue
            for j in range(X):
                temp=0
                if j+3<X:
                    temp+=dp[v][j+3]
                if j-3>=0:
                    temp+=dp[v][j-3]
                dp[p][j]*=temp
                dp[p][j]%=mod
                    
                    
        ans=0
        
        if flag==0:
            for j in range(X):
                ans=(ans+dp[0][j])%mod
                
        else:
            for j in range(3*N):
                ans=(ans+dp[0][j])%mod
                
            # for i in range(N):
            #     print(dp[i])
            
        return ans
        
    
    if C<=6*N:
        ans=calc_X(C,0)
        print(ans)
    else:
        temp=calc_X(6*N,1)
        rem=C-6*N
        temp2=pow(2,N-1,mod)
        
        ans= (temp*2 + temp2*rem)%mod
        
        print(ans)
        


                    
                    

main()
0