結果
問題 | No.1333 Squared Sum |
ユーザー |
![]() |
提出日時 | 2021-04-17 23:11:53 |
言語 | PyPy3 (7.3.15) |
結果 | WA |
実行時間 | - |
コード長 | 2,613 bytes |
コンパイル時間 | 282 ms |
コンパイル使用メモリ | 82,600 KB |
実行使用メモリ | 210,852 KB |
最終ジャッジ日時 | 2024-07-04 04:17:42 |
合計ジャッジ時間 | 39,729 ms |
ジャッジサーバーID (参考情報) | judge4 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 25 WA * 19 |
ソースコード
import sys

MOD = 1000000007
# Modular inverse of 2 modulo the prime MOD (equals pow(2, MOD - 2, MOD)).
INV2 = (MOD + 1) // 2


class Tree:
    """Weighted undirected tree supporting a two-pass rerooting DP.

    Vertices are 0-indexed integers in ``range(n)``.  During ``rerooting``
    the weight of the edge currently being folded is published in the
    attribute ``self.w`` so that the user-supplied ``op`` can read it.
    """

    def __init__(self, n):
        self.n = n
        self.tree = [[] for _ in range(n)]  # adjacency lists
        self.cost = {}  # edge weight keyed by u * n + v with u < v
        self.root = None

    def add_edge(self, u, v, c):
        """Add the undirected edge {u, v} with weight c."""
        if u > v:
            u, v = v, u
        self.tree[u].append(v)
        self.tree[v].append(u)
        self.cost[u * self.n + v] = c

    def get_cost(self, u, v):
        """Return the weight of edge {u, v}, or 0 if either endpoint is None."""
        if u is None or v is None:
            return 0
        if u < v:
            return self.cost[u * self.n + v]
        return self.cost[v * self.n + u]

    def set_root(self, r=0):
        """Root the tree at r.

        Fills ``self.par`` (parent of each vertex, None for the root) and
        ``self.ord`` (an order in which every parent precedes its children).
        Uses an explicit stack, so deep trees do not hit the recursion limit.
        """
        self.root = r
        self.par = [None] * self.n
        self.ord = [r]
        stack = [r]
        while stack:
            v = stack.pop()
            for adj in self.tree[v]:
                if self.par[v] == adj:
                    continue
                self.par[adj] = v
                self.ord.append(adj)
                stack.append(adj)

    def rerooting(self, op, e, merge, id):
        """Return, for every vertex v, the DP value of the whole tree rooted at v.

        op(acc, child): fold a child's subtree value into acc; the weight of
            the connecting edge is available as ``self.w`` during the call.
        e: DP value of a single vertex; every fold starts from it.
        merge(left, right): combine a prefix fold and a suffix fold taken over
            the same parent's children (both started from ``e``).
        id: stand-in for the (absent) part above the root; it is expected
            that op(x, id) with self.w == 0 returns x.
        """
        if self.root is None:
            self.set_root()
        dp = [e] * self.n
        lt = [id] * self.n   # lt[c]: fold of c's earlier siblings, from e
        rt = [id] * self.n   # rt[c]: fold of c's later siblings, from e
        inv = [id] * self.n  # inv[c]: tree minus c's subtree, rooted at par[c]

        # Bottom-up pass: subtree values plus per-child prefix/suffix folds.
        for v in self.ord[::-1]:
            tl = tr = e
            for adj in self.tree[v]:
                if self.par[v] == adj:
                    continue
                lt[adj] = tl
                self.w = self.get_cost(v, adj)
                tl = op(tl, dp[adj])
            for adj in self.tree[v][::-1]:
                if self.par[v] == adj:
                    continue
                rt[adj] = tr
                self.w = self.get_cost(v, adj)
                tr = op(tr, dp[adj])
            dp[v] = tr

        # Top-down pass: attach "everything outside v's subtree" to dp[v].
        for v in self.ord:
            if v == self.root:
                continue
            p = self.par[v]
            pp = self.par[p]
            # When p is the root, pp is None: the edge weight is 0 and
            # inv[p] still holds id.
            self.w = self.get_cost(p, pp)
            inv[v] = op(merge(lt[v], rt[v]), inv[p])
            self.w = self.get_cost(v, p)
            dp[v] = op(dp[v], inv[v])
        return dp


def solve(n, edges):
    """Return the sum over unordered vertex pairs {i, j} of dist(i, j)**2, mod MOD.

    n is the vertex count; edges is an iterable of (u, v, w) tuples with
    0-indexed endpoints and edge weight w.
    """
    t = Tree(n)
    for u, v, w in edges:
        t.add_edge(u, v, w)

    # DP value of a rooted subtree:
    #   (vertex count, sum of distances from the root,
    #    sum of squared distances from the root), the sums taken mod MOD.
    e = (1, 0, 0)         # a single vertex
    identity = (0, 0, 0)  # nothing above the root

    def op(p, c):
        ps, pd, pds = p
        cs, cd, cds = c
        w = t.w  # weight of the edge joining p's root to c's root
        size = ps + cs
        dist = (pd + cd + w * cs) % MOD
        # sum (d + w)^2 = sum d^2 + 2*w*sum d + w^2 * count.
        # Reducing w*w first keeps intermediates within a machine word.
        dsq = (pds + cds + w * w % MOD * cs + 2 * w * cd) % MOD
        return size, dist, dsq

    def merge(lt, rt):
        ls, ld, lds = lt
        rs, rd, rds = rt
        # Both folds already include the shared parent vertex once each.
        return ls + rs - 1, (ld + rd) % MOD, (lds + rds) % MOD

    dp = t.rerooting(op, e, merge, identity)
    # Summing over every root counts each unordered pair twice.  The total
    # is reduced mod MOD and may therefore be odd, so halve it with the
    # modular inverse of 2 rather than integer division.
    total = sum(d[2] for d in dp) % MOD
    return total * INV2 % MOD


def main():
    readline = sys.stdin.buffer.readline
    n = int(readline())
    edges = []
    for _ in range(n - 1):
        u, v, w = map(int, readline().split())
        edges.append((u - 1, v - 1, w))
    print(solve(n, edges))


if __name__ == "__main__":
    main()