結果

問題 No.1002 Twotone
ユーザー square1001square1001
提出日時 2020-03-01 20:23:08
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 3,738 bytes
コンパイル時間 1,533 ms
コンパイル使用メモリ 113,448 KB
実行使用メモリ 60,104 KB
最終ジャッジ日時 2024-10-13 20:50:10
合計ジャッジ時間 8,312 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 2 ms
5,248 KB
testcase_02 AC 1 ms
5,248 KB
testcase_03 AC 215 ms
35,732 KB
testcase_04 WA -
testcase_05 AC 288 ms
45,308 KB
testcase_06 AC 2 ms
5,248 KB
testcase_07 AC 110 ms
26,716 KB
testcase_08 AC 191 ms
41,836 KB
testcase_09 AC 186 ms
41,708 KB
testcase_10 AC 2 ms
5,248 KB
testcase_11 AC 167 ms
33,164 KB
testcase_12 AC 223 ms
42,036 KB
testcase_13 AC 230 ms
41,976 KB
testcase_14 AC 2 ms
5,248 KB
testcase_15 AC 122 ms
25,476 KB
testcase_16 AC 225 ms
42,392 KB
testcase_17 AC 228 ms
42,228 KB
testcase_18 AC 2 ms
5,248 KB
testcase_19 AC 161 ms
44,472 KB
testcase_20 AC 206 ms
60,104 KB
testcase_21 AC 207 ms
59,112 KB
testcase_22 AC 2 ms
5,248 KB
testcase_23 AC 150 ms
34,124 KB
testcase_24 AC 212 ms
55,196 KB
testcase_25 AC 207 ms
51,852 KB
testcase_26 AC 2 ms
5,248 KB
testcase_27 AC 116 ms
31,236 KB
testcase_28 AC 224 ms
45,164 KB
testcase_29 AC 186 ms
45,240 KB
testcase_30 AC 2 ms
5,248 KB
testcase_31 AC 159 ms
44,720 KB
testcase_32 AC 204 ms
44,792 KB
testcase_33 AC 175 ms
44,856 KB
testcase_34 AC 221 ms
57,104 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <map>
#include <vector>
#include <iostream>
#include <algorithm>
#include <functional>
using namespace std;
class edge {
public:
	int to, col;
	edge() : to(-1), col(-1) {};
	edge(int to_, int col_) : to(to_), col(col_) {};
};
int main() {
	cin.tie(0);
	ios_base::sync_with_stdio(false);
	int N, K;
	cin >> N >> K;
	vector<vector<edge> > G(N);
	for (int i = 0; i < N - 1; ++i) {
		int a, b, col;
		cin >> a >> b >> col; --a, --b, --col;
		G[a].push_back(edge(b, col));
		G[b].push_back(edge(a, col));
	}
	vector<vector<edge> > child(N); vector<edge> par(N);
	vector<int> ord;
	function<void(int, int)> make_order = [&](int pos, int pre) {
		for (edge e : G[pos]) {
			if (e.to == pre) continue;
			par[e.to] = edge(pos, e.col);
			child[pos].push_back(e);
			make_order(e.to, pos);
		}
		ord.push_back(pos);
	};
	make_order(0, -1);
	vector<int> group(N, -1);
	vector<map<int, int> > d(N);
	vector<int> dsum(N);
	long long ans = 0;
	for (int i : ord) {
		sort(child[i].begin(), child[i].end(), [&](edge& e1, edge& e2) {
			return d[group[e1.to]].size() > d[group[e2.to]].size();
		});
		group[i] = (child[i].empty() ? i : group[child[i][0].to]);
		map<pair<int, int>, int> s;
		map<int, int> subs;
		for (int j = 0; j < child[i].size(); ++j) {
			edge e = child[i][j];
			ans += dsum[group[e.to]] - d[group[e.to]][-1];
			if (j >= 1) {
				for (pair<int, int> k : d[group[e.to]]) {
					int va = k.first, vb = e.col;
					if (va > vb) swap(va, vb);
					if (va == vb) va = -1;
					if (va != -1) {
						if (s.find(make_pair(va, vb)) != s.end()) {
							ans += 1LL * s[make_pair(va, vb)] * k.second;
						}
						if (s.find(make_pair(-1, va)) != s.end()) {
							ans += 1LL * s[make_pair(-1, vb)] * k.second;
						}
						if (s.find(make_pair(-1, vb)) != s.end()) {
							ans += 1LL * s[make_pair(-1, vb)] * k.second;
						}
					}
					else {
						if (subs.find(-1) != subs.end()) {
							ans += 1LL * subs[-1] * k.second;
						}
						if (subs.find(vb) != subs.end()) {
							ans += 1LL * subs[vb] * k.second;
						}
						if (s.find(make_pair(-1, vb)) != s.end()) {
							ans -= 2LL * s[make_pair(-1, vb)] * k.second;
						}
					}
				}
				for (pair<int, int> k : d[group[e.to]]) {
					int va = k.first, vb = e.col;
					if (va > vb) swap(va, vb);
					if (va == vb) va = -1;
					s[make_pair(va, vb)] += k.second;
					subs[va] += k.second;
					subs[vb] += k.second;
				}
			}
		}
		int mcol = (child[i].empty() ? -1 : child[i][0].col);
		for (pair<pair<int, int>, int> j : s) {
			int va = j.first.first, vb = j.first.second;
			if (va == mcol) va = -1;
			if (vb == mcol) vb = -1;
			if (va > vb) swap(va, vb);
			if (vb == -1) {
				ans += 1LL * (dsum[group[i]] - d[group[i]][-1]) * j.second;
			}
			else if (va == -1) {
				if (d[group[i]].find(vb) != d[group[i]].end()) {
					ans += 1LL * d[group[i]][vb] * j.second;
				}
				ans += 1LL * d[group[i]][-1] * j.second;
			}
		}
		int pcol = par[i].col;
		if (mcol != pcol && mcol != -1) {
			int val = 0, valp = 0;
			for (pair<int, int> j : d[group[i]]) {
				if (j.first == -1) {
					val = j.second;
				}
				else if (j.first == pcol) {
					valp = j.second;
				}
				else {
					dsum[group[i]] -= j.second;
				}
			}
			d[group[i]].clear();
			d[group[i]][mcol] += val + valp;
		}
		for (int j = 1; j < child[i].size(); ++j) {
			int col = child[i][j].col;
			for (pair<int, int> k : d[group[child[i][j].to]]) {
				int va = col, vb = k.first;
				if (va == pcol) va = -1;
				if (vb == pcol) vb = -1;
				if (va > vb) swap(va, vb);
				if (va == vb) va = -1;
				if (va != -1) continue;
				dsum[group[i]] += k.second;
				d[group[i]][vb] += k.second;
			}
		}
		dsum[group[i]] += 1;
		d[group[i]][-1] += 1;
	}
	cout << ans << endl;
	return 0;
}
0