結果
問題 | No.1769 Don't Stop the Game |
ユーザー |
|
提出日時 | 2021-10-27 13:36:51 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,035 ms / 3,000 ms |
コード長 | 2,852 bytes |
コンパイル時間 | 150 ms |
コンパイル使用メモリ | 82,780 KB |
実行使用メモリ | 171,576 KB |
最終ジャッジ日時 | 2024-06-29 17:49:01 |
合計ジャッジ時間 | 19,261 ms |
ジャッジサーバーID (参考情報) | judge5 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 28 |
ソースコード
# yukicoder No.1769 "Don't Stop the Game" (reformatted submission).
#
# Input (stdin): n, then n-1 lines "a b c" -- an undirected tree edge between
# vertices a and b (1-indexed) with weight c.
# Output: the number of ordered pairs (A, B), A != B, such that no vertex w on
# the path from A to B (A excluded, B included) has path-XOR(A..w) == 0.
#
# With x[v] = XOR of edge weights from the root to v, path-XOR(A..w) equals
# x[A] ^ x[w], so (A, B) is "bad" exactly when some w on (A, B] has
# x[w] == x[A].  The answer is n*(n-1) minus the bad pairs, which are counted
# with a small-to-large ("DSU on tree") sweep over per-XOR-value aggregates.
import sys


def solve(n, edges):
    """Return the number of non-bad ordered pairs (A, B), A != B.

    Args:
        n: number of vertices; vertices are 0-indexed here.
        edges: iterable of (a, b, c) tree edges with 0-indexed endpoints
            and weight c.
    """
    g = [[] for _ in range(n)]
    for a, b, c in edges:
        g[a].append((b, c))
        g[b].append((a, c))

    # ---- Pass 1: iterative DFS from vertex 0 -----------------------------
    # siz[v]: subtree size, x[v]: root-to-v XOR, par[v]: parent (-1 = root).
    # On the way down the parent edge is deleted from g[v], so afterwards g[v]
    # holds only children; on the way up the children are sorted by
    # decreasing subtree size, so g[v][0] is the heavy child.
    siz = [0] * n
    x = [0] * n
    par = [-1] * n
    stack = [0]
    while stack:
        cur = stack.pop()
        if siz[cur]:
            # Second (post-order) visit: all children are already finished.
            g[cur].sort(key=lambda e: -siz[e[0]])
            for child, _ in g[cur]:
                siz[cur] += siz[child]
        else:
            # First visit: re-push cur for its post-order visit, then children.
            stack.append(cur)
            par_idx = -1
            for i, (to, w) in enumerate(g[cur]):
                if siz[to]:
                    # In a tree the only already-visited neighbour is the parent.
                    par_idx = i
                else:
                    x[to] = x[cur] ^ w
                    par[to] = cur
                    stack.append(to)
            if par_idx != -1:
                # Remove the parent edge in O(1): swap it to the end and pop.
                g[cur][-1], g[cur][par_idx] = g[cur][par_idx], g[cur][-1]
                g[cur].pop()
            siz[cur] = 1  # marks cur as visited

    # ---- Compress XOR values to 0..k-1 -----------------------------------
    sorted_x = sorted(set(x))
    index_of = {value: i for i, value in enumerate(sorted_x)}
    x = [index_of[value] for value in x]
    k = len(sorted_x)

    # ---- Pass 2: small-to-large sweep ------------------------------------
    # When vertex v is finished, the aggregates describe exactly subtree(v):
    #   y_sum[j]   = sum of y[u]   over u in the subtree with x[u] == j
    #              = number of subtree vertices having an x == j vertex among
    #                their ancestors-or-self inside the subtree,
    #   cnt_sum[j] = sum of cnt[u] over u in the subtree with x[u] == j
    #              = number of "topmost" x == j vertices of the subtree.
    # y[v]   = number of w in subtree(v) (w = v included) with no u on (v, w]
    #          such that x[u] == x[v].
    # cnt[v] = inclusion-exclusion weight that makes the cnt_sum identity hold.
    y_sum = [0] * k
    cnt_sum = [0] * k
    y = [0] * n
    cnt = [0] * n
    ans = n * (n - 1)

    def add_subtree(root, sign):
        """Add (sign=1) or remove (sign=-1) all of subtree(root) in the sums."""
        todo = [root]
        while todo:
            v = todo.pop()
            y_sum[x[v]] += sign * y[v]
            cnt_sum[x[v]] += sign * cnt[v]
            for child, _ in g[v]:
                todo.append(child)

    # Entries are (vertex, clear_after, children_done).  Light children
    # (index >= 1) are cleared after finishing.  The heavy child (index 0) is
    # pushed first, so it is popped last and its aggregates are kept.  Every
    # vertex therefore starts its processing with empty aggregates.
    stack2 = [(0, False, False)]
    while stack2:
        cur, clear_after, children_done = stack2.pop()
        children = g[cur]
        if not children_done:
            stack2.append((cur, clear_after, True))
            for i, (child, _) in enumerate(children):
                stack2.append((child, i != 0, False))
            continue

        # The sums hold subtree(heavy child); add the light children's subtrees.
        for child, _ in children[1:]:
            add_subtree(child, 1)

        xc = x[cur]
        y[cur] = siz[cur] - y_sum[xc]
        cnt[cur] = 1 - cnt_sum[xc]
        # Case 1 -- B inside subtree(A = cur): bad iff blocked below cur.
        ans -= y_sum[xc]
        y_sum[xc] += y[cur]
        cnt_sum[xc] += cnt[cur]

        if cur == 0:
            # Case 3 for vertices A with no ancestor of equal value (values
            # j != x[root]; the root would block the others): bad B are those
            # blocked by a j-vertex.
            for j in range(k):
                if j != xc:
                    ans -= y_sum[j] * cnt_sum[j]
        else:
            xp = x[par[cur]]
            # Case 2 -- A is a topmost xp-vertex of subtree(cur), so par[cur] is
            # A's nearest equal-valued ancestor: every B outside subtree(cur)
            # passes through par[cur] and is bad.
            ans -= cnt_sum[xp] * (n - siz[cur])
            # Case 3 -- the same A with B inside subtree(cur): bad iff B is
            # blocked by an xp-vertex on the way down.
            ans -= y_sum[xp] * cnt_sum[xp]
            # Each case-3 product also counted all of subtree(A) for its A.
            # Every non-root vertex is such an A exactly once, so undo
            # siz[cur] here.
            ans += siz[cur]

        if clear_after:
            add_subtree(cur, -1)

    return ans


def main():
    """Read the tree from stdin and print the answer."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    vals = list(map(int, data[1:1 + 3 * (n - 1)]))
    edges = [
        (vals[3 * i] - 1, vals[3 * i + 1] - 1, vals[3 * i + 2])
        for i in range(n - 1)
    ]
    print(solve(n, edges))


if __name__ == "__main__":
    main()