結果

問題 No.1002 Twotone
ユーザー suisensuisen
提出日時 2024-12-22 22:11:34
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 1,054 bytes
コンパイル時間 797 ms
コンパイル使用メモリ 82,320 KB
実行使用メモリ 153,448 KB
最終ジャッジ日時 2024-12-22 22:12:17
合計ジャッジ時間 35,673 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 42 ms
51,840 KB
testcase_01 AC 39 ms
52,352 KB
testcase_02 AC 39 ms
52,352 KB
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
testcase_06 AC 48 ms
54,272 KB
testcase_07 AC 1,000 ms
118,568 KB
testcase_08 AC 1,436 ms
144,776 KB
testcase_09 AC 1,399 ms
143,936 KB
testcase_10 AC 48 ms
55,040 KB
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 AC 53 ms
55,424 KB
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 WA -
testcase_24 WA -
testcase_25 WA -
testcase_26 AC 54 ms
55,332 KB
testcase_27 AC 850 ms
126,680 KB
testcase_28 AC 1,295 ms
153,448 KB
testcase_29 AC 1,150 ms
152,784 KB
testcase_30 AC 47 ms
53,704 KB
testcase_31 AC 1,166 ms
151,888 KB
testcase_32 AC 1,398 ms
152,800 KB
testcase_33 AC 1,248 ms
151,212 KB
testcase_34 AC 929 ms
139,708 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

n, k = map(int, input().split())
g: list[list[tuple[int, int, int]]] = [[] for _ in range(n)]
for eid in range(n - 1):
    u, v, c = map(int, input().split())
    u -= 1
    v -= 1
    g[u].append((c, eid, v))
    g[v].append((c, eid, u))

uf = [-1] * (n - 1)


def size(x: int) -> int:
    return -uf[find(x)]


def find(x: int) -> int:
    if uf[x] < 0:
        return x
    uf[x] = find(uf[x])
    return uf[x]


def merge(x: int, y: int) -> None:
    x = find(x)
    y = find(y)
    if x == y:
        return
    if uf[x] > uf[y]:
        x, y = y, x
    uf[x] += uf[y]
    uf[y] = x


for i in range(n):
    g[i].sort()
    siz = len(g[i])
    for j in range(siz - 1):
        c1, eid1, v1 = g[i][j]
        c2, eid2, v2 = g[i][j + 1]
        if c1 == c2:
            merge(eid1, eid2)

ans = 0
for i in range(n):
    siz = len(g[i])
    cmps: list[int] = []
    for j in range(siz):
        if j == 0 or g[i][j][0] != g[i][j - 1][0]:
            cmps.append(size(g[i][j][1]))
    ans += (sum(cmps) ** 2 - sum(e ** 2 for e in cmps)) // 2
print(ans)
0