結果
問題 | No.1333 Squared Sum |
ユーザー | — |
提出日時 | 2021-04-18 00:17:35 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,192 ms / 2,000 ms |
コード長 | 2,431 bytes |
コンパイル時間 | 142 ms |
コンパイル使用メモリ | 82,304 KB |
実行使用メモリ | 204,816 KB |
最終ジャッジ日時 | 2024-07-04 04:22:11 |
合計ジャッジ時間 | 29,963 ms |
ジャッジサーバーID (参考情報) | judge3 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 44 |
ソースコード
import io, os

# Prime modulus for the answer; being prime makes pow(2, MOD - 2, MOD) the
# modular inverse of 2 (Fermat's little theorem).
MOD = 1000000007


class Tree:
    """Undirected weighted tree on vertices 0 .. n-1 with a generic rerooting DP."""

    def __init__(self, n):
        """Create a tree with ``n`` vertices and no edges."""
        self.n = n
        # tree[v] holds (neighbor, edge_cost) pairs.
        self.tree = [[] for _ in range(n)]
        self.root = None

    def add_edge(self, u, v, c):
        """Add the undirected edge ``u - v`` with cost ``c``."""
        if u > v:
            u, v = v, u
        self.tree[u].append((v, c))
        self.tree[v].append((u, c))

    def set_root(self, r=0):
        """Root the tree at ``r`` using an iterative DFS.

        Sets ``par`` (parent of each vertex, ``None`` for the root), ``ord``
        (discovery order; every parent appears before its children) and
        ``cost`` (cost of the edge to the parent, 0 for the root).
        """
        self.root = r
        self.par = [None] * self.n
        self.ord = [r]
        # Sized by self.n rather than a script-level global so that Tree
        # works on its own.
        self.cost = [0] * self.n
        stack = [r]
        while stack:
            v = stack.pop()
            for adj, cost in self.tree[v]:
                if self.par[v] == adj:
                    continue
                self.cost[adj] = cost
                self.par[adj] = v
                self.ord.append(adj)
                stack.append(adj)

    def rerooting(self, op, e, merge, id):
        """Compute a DP value for every vertex as if it were the root.

        Roots the tree at vertex 0 first if ``set_root`` was never called.

        Args:
            op: ``op(acc, child, cost)`` folds the DP value ``child`` of a
                neighbouring component, attached through an edge of weight
                ``cost``, into the accumulator ``acc``.
            e: starting accumulator for a single vertex; also the DP value
                of a vertex with nothing folded in (e.g. a leaf).
            merge: ``merge(left, right)`` combines the prefix and suffix
                accumulators (both started from ``e``) of a vertex's other
                children.
            id: value standing for the (empty) component above the root.
                It is folded into the root's children via ``op`` with the
                root's ``cost`` entry, which is 0. (The name shadows the
                builtin; it is kept for interface compatibility.)

        Returns:
            List ``dp`` where ``dp[v]`` is the DP value of the whole tree
            viewed from ``v``.
        """
        if self.root is None:
            self.set_root()
        n = self.n
        par = self.par
        tree = self.tree
        dp = [e] * n
        # lt[c] / rt[c]: fold of c's siblings that come before / after c in
        # the parent's adjacency list, each starting from e.
        lt = [id] * n
        rt = [id] * n
        # inv[v]: DP value of the component left after removing v's subtree,
        # viewed from par[v].
        inv = [id] * n

        # Bottom-up pass: reverse discovery order visits children before parents.
        for v in reversed(self.ord):
            pv = par[v]
            tl = tr = e
            for adj, cost in tree[v]:
                if adj == pv:
                    continue
                lt[adj] = tl
                tl = op(tl, dp[adj], cost)
            for adj, cost in reversed(tree[v]):
                if adj == pv:
                    continue
                rt[adj] = tr
                tr = op(tr, dp[adj], cost)
            dp[v] = tr

        # Top-down pass: a parent's inv is final before its children use it.
        edge_cost = self.cost
        for v in self.ord:
            if v == self.root:
                continue
            p = par[v]
            inv[v] = op(merge(lt[v], rt[v]), inv[p], edge_cost[p])
            dp[v] = op(dp[v], inv[v], edge_cost[v])
        return dp


def op(p, v, d):
    """Fold component ``v`` into accumulator ``p`` across an edge of length ``d``.

    Each value is ``(size, dist, square)``: the vertex count, the sum of
    distances from the viewpoint vertex (mod MOD) and the sum of squared
    distances (mod MOD). Crossing the edge turns every distance x in ``v``
    into x + d, so the sums gain d * size and 2 * d * dist + d**2 * size.
    """
    p_size, p_dist, p_square = p
    v_size, v_dist, v_square = v
    size = p_size + v_size
    dist = (p_dist + v_dist + d * v_size) % MOD
    square = (p_square + v_square + d ** 2 * v_size + 2 * d * v_dist) % MOD
    return size, dist, square


def merge(lt, rt):
    """Combine two accumulators that both already contain the viewpoint vertex.

    That vertex is at distance 0, so only its count is de-duplicated (size - 1).
    """
    l_size, l_dist, l_square = lt
    r_size, r_dist, r_square = rt
    size = l_size + r_size - 1
    dist = (l_dist + r_dist) % MOD
    square = (l_square + r_square) % MOD
    return size, dist, square


def solve(n, edges):
    """Return the sum of squared distances over all unordered vertex pairs, mod MOD.

    Args:
        n: number of vertices.
        edges: iterable of ``(u, v, w)`` with 1-indexed endpoints and weight w.
    """
    tree = Tree(n)
    for u, v, w in edges:
        tree.add_edge(u - 1, v - 1, w)
    dp = tree.rerooting(op, (1, 0, 0), merge, (0, 0, 0))
    # Summing over every viewpoint counts each unordered pair twice.
    total = sum(square for _, _, square in dp) % MOD
    return total * pow(2, MOD - 2, MOD) % MOD


def main():
    """Read the tree from stdin and print the answer."""
    # Reads all of stdin at once, sized by os.fstat(0).st_size.
    # NOTE(review): this presumably assumes stdin is a regular file (as on
    # the judge); for a pipe st_size may not reflect the input length.
    readline = io.BytesIO(os.read(0, os.fstat(0).st_size)).readline
    n = int(readline())
    edges = [tuple(map(int, readline().split())) for _ in range(n - 1)]
    print(solve(n, edges))


if __name__ == "__main__":
    main()