結果
問題 | No.386 貪欲な領主 |
ユーザー |
![]() |
提出日時 | 2020-03-19 14:25:35 |
言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
結果 | AC |
実行時間 | 1,128 ms / 2,000 ms |
コード長 | 2,095 bytes |
コンパイル時間 | 266 ms |
コンパイル使用メモリ | 13,184 KB |
実行使用メモリ | 112,700 KB |
最終ジャッジ日時 | 2024-12-14 03:12:19 |
合計ジャッジ時間 | 13,562 ms |
ジャッジサーバーID (参考情報) | judge2 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 16 |
ソースコード
#!/usr/bin/env python3
"""Solution for yukicoder No.386.

Input (whitespace separated): ``N``; ``N - 1`` edges ``a b`` (0-indexed);
``N`` vertex weights; ``Q``; then ``Q`` queries ``A B M`` (0-indexed).
Each query contributes ``M`` times the sum of the vertex weights on the
A-B path, both endpoints included; the program prints the total.

With ``dist[v]`` = weight sum on the root-to-v path, the A-B path sum is
``dist[A] + dist[B] - 2 * dist[C] + cost[C]`` where ``C = LCA(A, B)``.
All LCAs are found at once with an Euler tour and a sparse-table
range-minimum query, vectorised with NumPy.
"""
import sys

import numpy as np


def EulerTour(graph, root=1, cost=None):
    """Walk the tree iteratively from ``root`` and record an Euler tour.

    Parameters
    ----------
    graph : list of list of int
        Adjacency lists indexed by vertex. Vertex 0 is used as the
        "no parent" sentinel, so real vertices must be numbered from 1.
        The lists are copied; ``graph`` itself is not modified.
    root : int
        Vertex the walk starts from.
    cost : sequence of int, optional
        Weight of every vertex; ``None`` means all weights are 0.

    Returns
    -------
    par, tour, depth, dist : lists
        ``par[v]`` is the parent of v (0 for the root).
        ``depth[v]`` is the number of edges from the root to v.
        ``dist[v]`` is the sum of ``cost`` over the root-to-v path, both
        ends included.
        ``tour`` holds the root, then the vertex reached after every step:
        a child when it is entered, the parent when a subtree is left.
        It ends with ``par[root]`` (the 0 sentinel).
    """
    V = len(graph)
    if cost is None:
        cost = [0] * V
    # The walk consumes neighbours by popping them, so work on a copy.
    adj = [list(nbrs) for nbrs in graph]
    par = [0] * V
    depth = [0] * V
    dist = [0] * V
    dist[root] = cost[root]
    tour = [root]
    stack = [root]
    while stack:
        x = stack[-1]
        if not adj[x]:
            # Every neighbour handled: leave x and step back to its parent.
            stack.pop()
            tour.append(par[x])
            continue
        y = adj[x].pop()
        if y == par[x]:
            continue
        par[y] = x
        depth[y] = depth[x] + 1
        dist[y] = dist[x] + cost[y]
        stack.append(y)
        tour.append(y)
    return par, tour, depth, dist


def build_sparse_table(tour_d):
    """Build a range-minimum sparse table over the array ``tour_d``.

    ``sp[k, i]`` is a position of a minimum of ``tour_d[i:i + 2**k]``.
    Only entries with ``i + 2**k <= len(tour_d)`` are meaningful.
    """
    size = len(tour_d)
    levels = max(size.bit_length(), 1)
    sp = np.empty((levels, size), np.int32)
    sp[0] = np.arange(size)
    for k in range(1, levels):
        prev, width = sp[k - 1], 1 << (k - 1)
        left = prev[:-width]
        right = prev[width:]
        take_right = tour_d[left] > tour_d[right]
        sp[k] = prev
        # sp[k, :-width] is a view, so the masked assignment writes into sp.
        sp[k, :-width][take_right] = right[take_right]
    return sp


def LCA(A, B, idx, sp, tour_d, tour_arr):
    """Return the lowest common ancestor of every pair ``(A[i], B[i])``.

    Parameters
    ----------
    A, B : ndarray of int
        Vertex arrays of equal length.
    idx : ndarray of int
        ``idx[v]`` is a position of v in the Euler tour. Any occurrence
        works, because the shallowest vertex between any occurrences of
        u and v is their LCA.
    sp : ndarray
        Sparse table returned by ``build_sparse_table(tour_d)``.
    tour_d : ndarray of int
        Depth at each tour position.
    tour_arr : ndarray of int
        The Euler tour itself.
    """
    pos_a = idx[A]
    pos_b = idx[B]
    L = np.minimum(pos_a, pos_b)
    R = np.maximum(pos_a, pos_b)
    length = R - L + 1
    # k = floor(log2(length)); the windows [L, L + 2**k) and
    # (R - 2**k, R] then cover [L, R]. Bounded by the table height rather
    # than a fixed iteration count.
    k = np.zeros_like(length)
    for level in range(1, sp.shape[0]):
        k[length >= (1 << level)] = level
    x = sp[k, L]
    y = sp[k, R - (1 << k) + 1]
    return np.where(tour_d[x] < tour_d[y], tour_arr[x], tour_arr[y])


def solve(data):
    """Parse the full problem input (``bytes``) and return the answer."""
    tokens = np.array(data.split(), np.int64)
    N = int(tokens[0])
    pos = 1
    # Shift vertices to 1..N so that 0 can serve as the "no parent" sentinel.
    edges = tokens[pos:pos + 2 * (N - 1)].reshape(-1, 2) + 1
    pos += 2 * (N - 1)
    cost = np.zeros(N + 1, np.int64)
    cost[1:] = tokens[pos:pos + N]
    pos += N
    Q = int(tokens[pos])
    ABM = tokens[pos + 1:pos + 1 + 3 * Q].reshape(-1, 3)
    A = ABM[:, 0] + 1
    B = ABM[:, 1] + 1
    M = ABM[:, 2]

    graph = [[] for _ in range(N + 1)]
    for a, b in edges.tolist():
        graph[a].append(b)
        graph[b].append(a)

    _, tour, depth, dist = EulerTour(graph, root=1, cost=cost.tolist())
    tour_arr = np.array(tour)
    tour_d = np.array(depth)[tour_arr]
    idx = np.zeros(len(depth), np.int64)
    idx[tour_arr] = np.arange(len(tour))
    sp = build_sparse_table(tour_d)

    C = LCA(A, B, idx, sp, tour_d, tour_arr)
    dist = np.array(dist, np.int64)
    return int(((dist[A] + dist[B] - 2 * dist[C] + cost[C]) * M).sum())


def main():
    """Read the problem input from stdin and print the answer."""
    print(solve(sys.stdin.buffer.read()))


if __name__ == "__main__":
    main()