結果
| 問題 |
No.1442 I-wate Shortest Path Problem
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2021-02-01 20:20:44 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
AC
|
| 実行時間 | 2,315 ms / 3,000 ms |
| コード長 | 2,638 bytes |
| コンパイル時間 | 439 ms |
| コンパイル使用メモリ | 81,792 KB |
| 実行使用メモリ | 206,720 KB |
| 最終ジャッジ日時 | 2024-10-11 12:26:53 |
| 合計ジャッジ時間 | 24,673 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 25 |
ソースコード
def main():
import sys
input=sys.stdin.buffer.readline
sys.setrecursionlimit(10**7)
from heapq import heappush,heappop
n,k=map(int,input().split())
g=[[]for _ in range(n+k)]
for _ in range(n-1):
a,b,c=map(int,input().split())
a-=1
b-=1
g[a].append((b,c))
g[b].append((a,c))
m=[0]*k
p=[0]*k
x=[[]for _ in range(k)]
for i in range(k):
m[i],p[i]=map(int,input().split())
x[i]=list(map(int,input().split()))
for j in range(m[i]):
x[i][j]-=1
q=int(input())
u=[0]*q
v=[0]*q
for i in range(q):
u[i],v[i]=map(int,input().split())
u[i]-=1
v[i]-=1
logn=17
dep=[0]*n
dis=[1<<60]*n
nxt=[[-1]*n for _ in range(logn)]
def dfs(cur,par,cur_dep,cur_dis):
dep[cur]=cur_dep
dis[cur]=cur_dis
nxt[0][cur]=par
for to,cost in g[cur]:
if to!=par:
dfs(to,cur,cur_dep+1,cur_dis+cost)
dfs(0,-1,0,0)
for j in range(logn-1):
for i in range(n):
if nxt[j][i]!=-1:
nxt[j+1][i]=nxt[j][nxt[j][i]]
def lca(a,b):
if dep[a]>dep[b]:
a,b=b,a
d=dep[b]-dep[a]
for j in range(logn):
if d>>j&1:
b=nxt[j][b]
if a==b:
return a
for j in range(logn-1,-1,-1):
if nxt[j][a]!=nxt[j][b]:
a=nxt[j][a]
b=nxt[j][b]
return nxt[0][a]
ans=[dis[u[i]]+dis[v[i]]-dis[lca(u[i],v[i])]*2 for i in range(q)]
dis_color=[[1<<60]*k for _ in range(k)]
dp=[[1<<60]*(n)for _ in range(k)]
dik=[]
border=100100
for i in range(k):
for j in range(m[i]):
dp[i][x[i][j]]=0
heappush(dik,x[i][j])
while dik:
d=heappop(dik)
cur=d%border
d//=border
if dp[i][cur]<d:
continue
for to,cost in g[cur]:
if dp[i][to]>dp[i][cur]+cost:
dp[i][to]=dp[i][cur]+cost
dik.append(dp[i][to]*border+to)
for j in range(i+1,k):
for c in range(m[j]):
if dis_color[i][j]>dp[i][x[j][c]]:
dis_color[i][j]=dp[i][x[j][c]]
for i in range(k):
dis_color[i][i]=0
for j in range(i+1,k):
dis_color[j][i]=dis_color[i][j]
for c in range(k):
for i in range(k):
for j in range(k):
if dis_color[i][j]>dis_color[i][c]+dis_color[c][j]+p[c]:
dis_color[i][j]=dis_color[i][c]+dis_color[c][j]+p[c]
for i in range(k):
for j in range(k):
if i==j:
dis_color[i][j]+=p[i]
else:
dis_color[i][j]+=p[i]+p[j]
for i in range(q):
res=ans[i]
for a in range(k):
for b in range(k):
if res>dp[a][u[i]]+dp[b][v[i]]+dis_color[a][b]:
res=dp[a][u[i]]+dp[b][v[i]]+dis_color[a][b]
print(res)
main()