結果
| 問題 |
No.1227 I hate ThREE
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2021-03-17 15:13:14 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
AC
|
| 実行時間 | 448 ms / 2,000 ms |
| コード長 | 2,670 bytes |
| コンパイル時間 | 666 ms |
| コンパイル使用メモリ | 82,768 KB |
| 実行使用メモリ | 123,552 KB |
| 最終ジャッジ日時 | 2024-11-14 18:25:39 |
| 合計ジャッジ時間 | 10,312 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 33 |
ソースコード
def main0(n,c,ab):
ki=[[] for _ in range(n)]
for a,b in ab:
a,b=a-1,b-1
ki[a].append(b)
ki[b].append(a)
mod=10**9+7
tree_order=[]
todo=[[0,-1]]
parent=[-1]*n
while todo:
v,p=todo.pop()
parent[v]=p
tree_order.append(v)
for nv in ki[v]:
if nv==p:continue
todo.append([nv,v])
tree_order.reverse()
# c<kx10^3ぐらいのときの解法をまず考える。
# 1~c -> 0~c-1とする
dp=[[1]*c for _ in range(n)]
# dp[v][i]頂点vの部分木を考え、頂点vに値iを入れる場合数
for v in tree_order:
# 親へ遷移
p=parent[v]
if p<0:continue
ary=[0]*c
for i in range(c):
if i-3>=0:
ary[i-3]+=dp[v][i]
ary[i-3]%=mod
if i+3<c:
ary[i+3]+=dp[v][i]
ary[i+3]%=mod
for i in range(c):
dp[p][i]*=ary[i]
dp[p][i]%=mod
ret=0
for x in dp[0]:
ret+=x
ret%=mod
return ret
def main1(n,c,ab):
if c<=2*n:return main0(n,c,ab)
ki=[[] for _ in range(n)]
for a,b in ab:
a,b=a-1,b-1
ki[a].append(b)
ki[b].append(a)
mod=10**9+7
dist=[0]*n
max_dist=0
tree_order=[]
todo=[[0,-1]]
parent=[-1]*n
while todo:
v,p=todo.pop()
parent[v]=p
tree_order.append(v)
for nv in ki[v]:
if nv==p:continue
todo.append([nv,v])
dist[nv]=dist[v]+1
max_dist=max(max_dist,dist[nv])
tree_order.reverse()
# 1~c -> 0~c-1 -> 0~6*n-1とする
dp=[[1]*(6*n) for _ in range(n)]
# dp[v][i]:頂点vの部分木についての解。頂点vに値iを入れる場合数。
for v in tree_order:
# 親へ遷移
p=parent[v]
if p<0:continue
ary=[0]*(6*n)
for i in range(6*n):
if i-3>=0:
ary[i-3]+=dp[v][i]
ary[i-3]%=mod
if i+3<6*n:
ary[i+3]+=dp[v][i]
ary[i+3]%=mod
for i in range(6*n):
dp[p][i]*=ary[i]
dp[p][i]%=mod
ret=0
for i in range(3*max_dist):
x=dp[0][i]
ret+=x
ret%=mod
x=dp[0][-1-i]
ret+=x
ret%=mod
ret+=dp[0][3*max_dist]*(c-3*(max_dist)*2)
ret%=mod
return ret
import sys
input=sys.stdin.readline
if __name__=='__main__':
n,c=map(int,input().split())
ab=[list(map(int,input().split())) for _ in range(n-1)]
print(main1(n,c,ab))
if __name__=='__main__1':
from random import randint,shuffle
for _ in range(100):
n=randint(2,50)
c=randint(4,200)
ab=[]
su=[1]
mi=list(range(2,n+1))
shuffle(mi)
while mi:
a=mi.pop()
b=su[randint(0,len(su))-1]
ab.append([a,b])
ret0=main0(n,c,ab)
ret1=main1(n,c,ab)
if ret0!=ret1:
print(n,c)
for x in ab:print(*x)
print((ret0,ret1))
break