結果

問題 No.1002 Twotone
ユーザー square1001square1001
提出日時 2020-03-01 20:07:29
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 3,731 bytes
コンパイル時間 1,792 ms
コンパイル使用メモリ 111,296 KB
実行使用メモリ 60,096 KB
最終ジャッジ日時 2024-04-21 22:23:37
合計ジャッジ時間 11,767 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 3 ms
5,248 KB
testcase_01 AC 2 ms
5,376 KB
testcase_02 AC 2 ms
5,376 KB
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
testcase_06 AC 2 ms
5,376 KB
testcase_07 WA -
testcase_08 WA -
testcase_09 WA -
testcase_10 AC 3 ms
5,376 KB
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 AC 2 ms
5,376 KB
testcase_19 AC 295 ms
44,608 KB
testcase_20 AC 362 ms
60,096 KB
testcase_21 AC 359 ms
59,128 KB
testcase_22 AC 2 ms
5,376 KB
testcase_23 AC 202 ms
33,996 KB
testcase_24 AC 365 ms
55,280 KB
testcase_25 AC 358 ms
51,972 KB
testcase_26 AC 2 ms
5,376 KB
testcase_27 AC 163 ms
31,488 KB
testcase_28 AC 365 ms
45,232 KB
testcase_29 AC 271 ms
45,236 KB
testcase_30 AC 3 ms
5,376 KB
testcase_31 AC 235 ms
44,828 KB
testcase_32 AC 385 ms
44,852 KB
testcase_33 AC 272 ms
44,856 KB
testcase_34 AC 251 ms
57,108 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <map>
#include <vector>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;
class edge {
public:
	int to, col;
	edge() : to(-1), col(-1) {};
	edge(int to_, int col_) : to(to_), col(col_) {};
};
int main() {
	cin.tie(0);
	ios_base::sync_with_stdio(false);
	int N, K;
	cin >> N >> K;
	vector<vector<edge> > G(N);
	for (int i = 0; i < N - 1; ++i) {
		int a, b, col;
		cin >> a >> b >> col; --a, --b, --col;
		G[a].push_back(edge(b, col));
		G[b].push_back(edge(a, col));
	}
	vector<vector<edge> > child(N); vector<edge> par(N);
	vector<int> ord;
	function<void(int, int)> make_order = [&](int pos, int pre) {
		for (edge e : G[pos]) {
			if (e.to == pre) continue;
			par[e.to] = edge(pos, e.col);
			child[pos].push_back(e);
			make_order(e.to, pos);
		}
		ord.push_back(pos);
	};
	make_order(0, -1);
	vector<int> group(N, -1);
	vector<map<int, int> > d(N);
	vector<int> dsum(N);
	long long ans = 0;
	for (int i : ord) {
		sort(child[i].begin(), child[i].end(), [&](edge& e1, edge& e2) {
			return d[group[e1.to]].size() > d[group[e2.to]].size();
		});
		group[i] = (child[i].empty() ? i : group[child[i][0].to]);
		map<pair<int, int>, int> s;
		map<int, int> subs;
		for (int j = 0; j < child[i].size(); ++j) {
			edge e = child[i][j];
			ans += dsum[group[e.to]] - d[group[e.to]][-1];
			if (j >= 1) {
				for (pair<int, int> k : d[group[e.to]]) {
					int va = k.first, vb = e.col;
					if (va > vb) swap(va, vb);
					if (va == vb) va = -1;
					if (va != -1) {
						if (s.find(make_pair(va, vb)) != s.end()) {
							ans += 1LL * s[make_pair(va, vb)] * k.second;
						}
						if (s.find(make_pair(-1, va)) != s.end()) {
							ans += 1LL * s[make_pair(-1, vb)] * k.second;
						}
						if (s.find(make_pair(-1, vb)) != s.end()) {
							ans += 1LL * s[make_pair(-1, vb)] * k.second;
						}
					}
					else {
						if (subs.find(-1) != subs.end()) {
							ans += 1LL * subs[-1] * k.second;
						}
						if (subs.find(vb) != subs.end()) {
							ans += 1LL * subs[vb] * k.second;
						}
						if (s.find(make_pair(-1, vb)) != s.end()) {
							ans -= 2LL * s[make_pair(-1, vb)] * k.second;
						}
					}
				}
				for (pair<int, int> k : d[group[e.to]]) {
					int va = k.first, vb = e.col;
					if (va > vb) swap(va, vb);
					if (va == vb) va = -1;
					s[make_pair(va, vb)] += k.second;
					subs[va] += k.second;
					subs[vb] += k.second;
				}
			}
		}
		int mcol = (child[i].empty() ? -1 : child[i][0].col);
		for (pair<pair<int, int>, int> j : s) {
			int va = j.first.first, vb = j.first.second;
			if (va == mcol) va = -1;
			if (vb == mcol) vb = -1;
			if (va > vb) swap(va, vb);
			if (vb == -1) {
				ans += 1LL * (dsum[group[i]] - d[group[i]][-1]) * j.second;
			}
			else if (va == -1) {
				if (d[group[i]].find(vb) != d[group[i]].end()) {
					ans += 1LL * d[group[i]][vb] * j.second;
				}
				ans += 1LL * d[group[i]][-1] * j.second;
			}
		}
		int pcol = par[i].col;
		if (mcol != pcol && mcol != -1) {
			int val = 0, valp = 0;
			for (pair<int, int> j : d[group[i]]) {
				if (j.first == -1) {
					val = j.second;
				}
				else if (j.first == pcol) {
					valp = j.second;
				}
				else {
					dsum[group[i]] -= j.second;
				}
			}
			d[group[i]].clear();
			d[group[i]][mcol] += val + valp;
		}
		for (int j = 1; j < child[i].size(); ++j) {
			int col = child[i][j].col;
			for (pair<int, int> k : d[child[i][j].to]) {
				int va = col, vb = k.first;
				if (va == pcol) va = -1;
				if (vb == pcol) vb = -1;
				if (va > vb) swap(va, vb);
				if (va == vb) va = -1;
				if (va != -1) continue;
				dsum[group[i]] += k.second;
				d[group[i]][vb] += k.second;
			}
		}
		dsum[group[i]] += 1;
		d[group[i]][-1] += 1;
	}
	cout << ans << endl;
	return 0;
}
0