結果

問題 No.1333 Squared Sum
ユーザー LyricalMaestro
提出日時 2024-11-06 02:20:33
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,767 ms / 2,000 ms
コード長 4,002 bytes
コンパイル時間 260 ms
コンパイル使用メモリ 82,016 KB
実行使用メモリ 206,292 KB
最終ジャッジ日時 2024-11-06 02:21:19
合計ジャッジ時間 45,161 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 40 ms
54,948 KB
testcase_01 AC 42 ms
54,516 KB
testcase_02 AC 42 ms
55,996 KB
testcase_03 AC 1,657 ms
174,500 KB
testcase_04 AC 1,633 ms
174,064 KB
testcase_05 AC 1,596 ms
174,036 KB
testcase_06 AC 1,626 ms
173,948 KB
testcase_07 AC 1,642 ms
173,836 KB
testcase_08 AC 1,658 ms
174,756 KB
testcase_09 AC 1,638 ms
173,872 KB
testcase_10 AC 1,694 ms
173,996 KB
testcase_11 AC 1,623 ms
174,396 KB
testcase_12 AC 1,628 ms
174,840 KB
testcase_13 AC 936 ms
177,408 KB
testcase_14 AC 1,767 ms
175,464 KB
testcase_15 AC 1,754 ms
179,288 KB
testcase_16 AC 43 ms
55,268 KB
testcase_17 AC 42 ms
54,372 KB
testcase_18 AC 42 ms
56,148 KB
testcase_19 AC 42 ms
55,432 KB
testcase_20 AC 43 ms
54,816 KB
testcase_21 AC 42 ms
54,436 KB
testcase_22 AC 42 ms
55,892 KB
testcase_23 AC 42 ms
55,040 KB
testcase_24 AC 43 ms
54,912 KB
testcase_25 AC 42 ms
54,816 KB
testcase_26 AC 1,724 ms
176,364 KB
testcase_27 AC 1,742 ms
177,700 KB
testcase_28 AC 1,760 ms
176,328 KB
testcase_29 AC 956 ms
176,780 KB
testcase_30 AC 673 ms
118,200 KB
testcase_31 AC 428 ms
101,268 KB
testcase_32 AC 995 ms
135,396 KB
testcase_33 AC 815 ms
124,156 KB
testcase_34 AC 1,340 ms
155,892 KB
testcase_35 AC 995 ms
136,388 KB
testcase_36 AC 638 ms
112,608 KB
testcase_37 AC 633 ms
114,544 KB
testcase_38 AC 731 ms
120,084 KB
testcase_39 AC 1,144 ms
143,608 KB
testcase_40 AC 1,111 ms
206,156 KB
testcase_41 AC 1,108 ms
206,028 KB
testcase_42 AC 1,140 ms
206,292 KB
testcase_43 AC 1,102 ms
204,564 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

## https://yukicoder.me/problems/no/1333

from collections import deque

MOD = 10 ** 9 + 7


def _read_tree():
    """Read the weighted tree from standard input.

    Expected format: ``N`` followed by ``N - 1`` triples ``u v w`` giving the
    1-based endpoints and weight of each edge.

    Returns:
        ``(n, adj)`` where ``adj[v]`` is a list of ``(neighbor, weight)``
        pairs for vertex ``v``, using 0-based vertex ids.
    """
    import sys

    # One bulk read instead of N - 1 separate input() calls, which are
    # comparatively slow when N is large.
    tokens = sys.stdin.buffer.read().split()
    n = int(tokens[0])
    adj = [[] for _ in range(n)]
    for i in range(1, 3 * (n - 1) + 1, 3):
        u = int(tokens[i]) - 1
        v = int(tokens[i + 1]) - 1
        w = int(tokens[i + 2])
        adj[u].append((v, w))
        adj[v].append((u, w))
    return n, adj


def _pairwise_squared_distance_sum(n, adj, mod=MOD):
    """Return the sum of dist(a, b) ** 2 over all unordered pairs a != b, mod ``mod``.

    The tree is rooted at vertex 0 and processed bottom-up in a single pass.
    For every vertex ``v`` we keep, over the vertices merged into it so far
    (``v`` itself included, at distance 0):

    * ``cnt[v]`` - how many vertices there are,
    * ``s1[v]``  - the sum of their distances to ``v``,
    * ``s2[v]``  - the sum of their squared distances to ``v``.

    When child ``c`` (edge weight ``w``) is merged into its parent ``p``,
    every distance ``d`` measured from ``c`` becomes ``d + w`` measured from
    ``p``, so the shifted sums are ``t1 = s1 + cnt*w`` and
    ``t2 = s2 + 2*w*s1 + cnt*w*w``.  Pairing every vertex of ``c``'s subtree
    with every vertex already merged into ``p`` adds
    ``sum (da + db)**2 = cnt_c*s2_p + cnt_p*t2 + 2*s1_p*t1``.  Each pair is
    therefore counted exactly once, at its lowest common ancestor, so no
    rerooting and no division by 2 are needed.

    Args:
        n: number of vertices.
        adj: adjacency lists of ``(neighbor, weight)`` pairs describing a tree.
        mod: modulus applied to the result.

    Returns:
        The sum modulo ``mod``; 0 when ``n <= 1``.
    """
    if n <= 1:
        return 0

    # BFS from vertex 0: record each vertex's parent and parent-edge weight,
    # plus a visiting order in which every parent precedes its children.
    parent = [-1] * n
    parent_weight = [0] * n
    order = []
    queue = deque([0])
    while queue:
        v = queue.popleft()
        order.append(v)
        for u, w in adj[v]:
            if u != parent[v]:
                parent[u] = v
                parent_weight[u] = w
                queue.append(u)

    cnt = [1] * n
    s1 = [0] * n
    s2 = [0] * n
    total = 0
    # Reversed BFS order visits every child before its parent, so a vertex's
    # sums are complete by the time it is merged upward.
    for v in reversed(order):
        p = parent[v]
        if p < 0:
            continue  # the root has nothing to merge into
        w = parent_weight[v] % mod
        c = cnt[v]
        t1 = (s1[v] + c * w) % mod
        t2 = (s2[v] + 2 * w * s1[v] + c * w * w) % mod
        total = (total + c * s2[p] + cnt[p] * t2 + 2 * s1[p] * t1) % mod
        cnt[p] += c
        s1[p] = (s1[p] + t1) % mod
        s2[p] = (s2[p] + t2) % mod
    return total


def main():
    """Solve yukicoder No.1333: read a tree, print the pairwise squared-distance sum."""
    n, adj = _read_tree()
    print(_pairwise_squared_distance_sum(n, adj))


if __name__ == "__main__":
    main()