結果
問題 | No.1333 Squared Sum |
ユーザー | |
提出日時 | 2024-11-06 02:20:33 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,767 ms / 2,000 ms |
コード長 | 4,002 bytes |
コンパイル時間 | 260 ms |
コンパイル使用メモリ | 82,016 KB |
実行使用メモリ | 206,292 KB |
最終ジャッジ日時 | 2024-11-06 02:21:19 |
合計ジャッジ時間 | 45,161 ms |
ジャッジサーバーID (参考情報) | judge2 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 44 |
ソースコード
## https://yukicoder.me/problems/no/1333
import sys
from collections import deque

MOD = 10 ** 9 + 7


def solve(n, edges):
    """Return the sum of dist(u, v)**2 over all unordered vertex pairs, mod MOD.

    Args:
        n: number of vertices, labelled 0 .. n-1.
        edges: iterable of (u, v, w) triples with 0-indexed endpoints and
            integer weight w; assumed to describe a tree on n vertices
            (vertices not reachable from vertex 0 are ignored).

    Returns:
        The sum, over all pairs u < v, of the squared weighted path length
        between u and v, reduced modulo MOD.

    Each pair is counted exactly once, at its lowest common ancestor when the
    tree is rooted at vertex 0.  While a vertex p absorbs its children one by
    one, the arrays describe the vertices merged into p so far (p included):
        cnt[p] -- how many there are,
        s1[p]  -- sum of their distances to p (mod MOD),
        s2[p]  -- sum of their squared distances to p (mod MOD).
    For a pair split between the merged part (distance a to p) and a newly
    merged child subtree (distance b to p), (a + b)**2 = a*a + 2*a*b + b*b.
    The cross contribution is therefore cnt[p]*D2 + cnt_c*s2[p] + 2*s1[p]*D1,
    where D1/D2 are the child's sums re-measured from p.
    """
    if n <= 1:
        return 0

    adj = [[] for _ in range(n)]
    for u, v, w in edges:
        adj[u].append((v, w))
        adj[v].append((u, w))

    # Iterative DFS from vertex 0. `order` lists every vertex before all of
    # its descendants, so walking it backwards handles children before parents.
    parent = [-1] * n
    parent_w = [0] * n
    visited = [False] * n
    visited[0] = True
    order = []
    stack = deque([0])
    while stack:
        v = stack.pop()
        order.append(v)
        for to, w in adj[v]:
            if not visited[to]:
                visited[to] = True
                parent[to] = v
                parent_w[to] = w
                stack.append(to)

    cnt = [1] * n  # each vertex starts with only itself merged
    s1 = [0] * n
    s2 = [0] * n
    total = 0
    for v in reversed(order):
        p = parent[v]
        if p < 0:
            continue  # the root has no parent to merge into
        w = parent_w[v]
        c = cnt[v]
        # Re-measure v's subtree sums from p: every distance grows by w.
        d1 = (s1[v] + c * w) % MOD
        d2 = (s2[v] + 2 * w * s1[v] + c * w * w) % MOD
        # Pairs with one end already merged into p and the other in v's subtree.
        total = (total + cnt[p] * d2 + c * s2[p] + 2 * s1[p] * d1) % MOD
        cnt[p] += c
        s1[p] = (s1[p] + d1) % MOD
        s2[p] = (s2[p] + d2) % MOD
    return total


def main():
    """Read N and N-1 lines "u v w" (1-indexed) from stdin and print the answer."""
    # Bulk-read stdin: calling input() once per edge is slow for large N.
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    vals = list(map(int, data[1:1 + 3 * (n - 1)]))
    edges = [
        (vals[i] - 1, vals[i + 1] - 1, vals[i + 2])
        for i in range(0, len(vals), 3)
    ]
    print(solve(n, edges))


if __name__ == "__main__":
    main()