結果

問題 No.1002 Twotone
ユーザー square1001square1001
提出日時 2020-03-01 21:02:29
言語 C++14 (gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 372 ms / 5,000 ms
コード長 5,574 bytes
コンパイル時間 1,813 ms
コンパイル使用メモリ 146,220 KB
実行使用メモリ 78,780 KB
最終ジャッジ日時 2024-10-13 20:52:41
合計ジャッジ時間 9,658 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 2 ms 6,820 KB
testcase_01 AC 1 ms 6,820 KB
testcase_02 AC 2 ms 6,816 KB
testcase_03 AC 262 ms 51,888 KB
testcase_04 AC 331 ms 66,444 KB
testcase_05 AC 339 ms 66,536 KB
testcase_06 AC 2 ms 6,820 KB
testcase_07 AC 139 ms 40,196 KB
testcase_08 AC 238 ms 64,056 KB
testcase_09 AC 239 ms 64,052 KB
testcase_10 AC 2 ms 6,816 KB
testcase_11 AC 209 ms 49,860 KB
testcase_12 AC 265 ms 63,736 KB
testcase_13 AC 274 ms 63,840 KB
testcase_14 AC 2 ms 6,816 KB
testcase_15 AC 143 ms 38,108 KB
testcase_16 AC 264 ms 63,744 KB
testcase_17 AC 292 ms 63,952 KB
testcase_18 AC 2 ms 6,816 KB
testcase_19 AC 195 ms 55,256 KB
testcase_20 AC 245 ms 72,512 KB
testcase_21 AC 225 ms 71,496 KB
testcase_22 AC 2 ms 6,820 KB
testcase_23 AC 125 ms 41,876 KB
testcase_24 AC 238 ms 67,668 KB
testcase_25 AC 243 ms 64,272 KB
testcase_26 AC 2 ms 6,816 KB
testcase_27 AC 168 ms 52,948 KB
testcase_28 AC 354 ms 77,772 KB
testcase_29 AC 283 ms 77,888 KB
testcase_30 AC 2 ms 6,816 KB
testcase_31 AC 275 ms 77,236 KB
testcase_32 AC 372 ms 77,528 KB
testcase_33 AC 276 ms 77,708 KB
testcase_34 AC 201 ms 78,780 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <vector>
#include <iostream>
#include <algorithm>
#include <functional>
#include <unordered_map>
using namespace std;
// A coloured edge as seen from one endpoint: `to` is the opposite vertex and
// `col` the edge colour. A default-constructed edge holds -1 in both fields,
// which serves as the "no edge / no colour" sentinel.
class edge {
public:
	int to = -1;
	int col = -1;
	edge() = default;
	edge(int to_, int col_) : to(to_), col(col_) {}
};
// Packs the pair (a, b) into one 64-bit key, a * 2^32 + b, for use as an
// unordered_map key.
// solve() passes a == -1 for single-colour entries. Left-shifting a negative
// signed value is undefined behaviour before C++20, so the high part is formed
// by multiplication instead of `a << 32`. For a and b within int range, the
// result equals the two's-complement value ((a << 32) + b) that solve() decodes
// with `key >> 32`.
inline long long hashing(long long a, long long b) {
	return a * (1LL << 32) + b;
}
// Counts unordered vertex pairs {u, v} whose tree path uses exactly two distinct
// edge colours. random_gen() cross-checks the result against the brute-force
// solve_easy(), which counts pairs whose path colour set has size 2.
//
// N: number of vertices. G: adjacency list; main() stores each edge in both
// directions. Vertex 0 is used as the root.
// NOTE(review): G is taken by value, so the whole adjacency list is copied.
//
// Approach: small-to-large merging of per-subtree colour maps over a post-order.
// After vertex i has been processed, with pcol = colour of the edge from i to its
// parent, d[group[i]] holds:
//   * key c (c != pcol): number of vertices v in subtree(i) whose path i -> v,
//     together with pcol, uses exactly the colours {pcol, c};
//   * key -1: number of vertices v in subtree(i) whose path i -> v uses no colour
//     other than pcol (this includes v == i itself).
// Vertices whose path would bring in a third colour are dropped. dsum[g] is the
// sum of the values stored in d[g]. group[i] reuses the map of i's child with the
// largest map instead of copying it. The root (pcol == -1) is processed last, and
// its map is not read afterwards.
long long solve(int N, vector<vector<edge> > G) {
	vector<vector<edge> > child(N); vector<edge> par(N);
	vector<int> ord;
	// Recursive DFS from the root. It fills child[] (child vertex + edge colour),
	// par[] (parent vertex + edge colour; par[0] stays edge(-1, -1)) and ord
	// (post-order: every vertex appears after all of its descendants).
	// NOTE(review): recursion depth equals the tree height; a very long path may
	// exhaust the call stack on some platforms.
	function<void(int, int)> make_order = [&](int pos, int pre) {
		for (edge e : G[pos]) {
			if (e.to == pre) continue;
			par[e.to] = edge(pos, e.col);
			child[pos].push_back(e);
			make_order(e.to, pos);
		}
		ord.push_back(pos);
	};
	make_order(0, -1);
	vector<int> group(N, -1);
	vector<unordered_map<int, int> > d(N);
	vector<int> dsum(N);
	long long ans = 0;
	for (int i : ord) {
		// Order children by map size, largest first: child[i][0] is the "heavy"
		// child whose map i inherits; the others are "light" children.
		sort(child[i].begin(), child[i].end(), [&](edge& e1, edge& e2) {
			return d[group[e1.to]].size() > d[group[e2.to]].size();
		});
		// A leaf starts a fresh map at index i; otherwise reuse the heavy child's map.
		group[i] = (child[i].empty() ? i : group[child[i][0].to]);
		// s: vertices of the light children handled so far, keyed by the colour set
		//    of their path from i: hashing(a, b) for two colours a < b, or
		//    hashing(-1, c) for a path using only colour c.
		// subs[c]: how many of those vertices have a path containing colour c.
		// subs[-1]: how many of those vertices have a single-coloured path.
		unordered_map<long long, int> s;
		unordered_map<int, int> subs;
		for (int j = 0; j < child[i].size(); ++j) {
			edge e = child[i][j];
			// Pairs (i, v) with v below this child: entries with a key other than -1
			// are exactly the two-colour paths. Key -1 is always present, because
			// every processed vertex adds itself under -1 (see the end of this loop).
			ans += dsum[group[e.to]] - d[group[e.to]][-1];
			if (j >= 1) {
				// Pair this light child's vertices with those of earlier light children.
				for (pair<int, int> k : d[group[e.to]]) {
					int va = k.first, vb = e.col;
					if (va > vb) swap(va, vb);
					if (va == vb) va = -1;
					// {va, vb} is now the colour set of the path from i
					// (va == -1 means the path uses only colour vb).
					if (va != -1) {
						// Two-colour path {va, vb}: the partner path must use the same
						// pair, only va, or only vb.
						if (s.find(hashing(va, vb)) != s.end()) {
							ans += 1LL * s[hashing(va, vb)] * k.second;
						}
						if (s.find(hashing(-1, va)) != s.end()) {
							ans += 1LL * s[hashing(-1, va)] * k.second;
						}
						if (s.find(hashing(-1, vb)) != s.end()) {
							ans += 1LL * s[hashing(-1, vb)] * k.second;
						}
					}
					else {
						// Single-colour path vb: the partner is either a single-colour
						// path of another colour (subs[-1] - s[(-1, vb)]) or a
						// two-colour path containing vb (subs[vb] - s[(-1, vb)]).
						if (subs.find(-1) != subs.end()) {
							ans += 1LL * subs[-1] * k.second;
						}
						if (subs.find(vb) != subs.end()) {
							ans += 1LL * subs[vb] * k.second;
						}
						if (s.find(hashing(-1, vb)) != s.end()) {
							ans -= 2LL * s[hashing(-1, vb)] * k.second;
						}
					}
				}
				// Add this child's vertices only after counting, so pairs inside one
				// child's subtree are not counted here.
				for (pair<int, int> k : d[group[e.to]]) {
					int va = k.first, vb = e.col;
					if (va > vb) swap(va, vb);
					if (va == vb) va = -1;
					s[hashing(va, vb)] += k.second;
					subs[va] += k.second;
					subs[vb] += k.second;
				}
			}
		}
		// Pair the heavy child's vertices with every light-child vertex in s. At this
		// point d[group[i]] is still the heavy child's map, keyed relative to mcol
		// (the colour of the edge i -> child[i][0]).
		int mcol = (child[i].empty() ? -1 : child[i][0].col);
		for (pair<long long, int> j : s) {
			// Decode the key. For negative keys, `>> 32` relies on arithmetic right
			// shift (implementation-defined before C++20).
			int va = (j.first >> 32), vb = j.first - ((long long)(va) << 32);
			// Re-express the light path's colours relative to mcol (-1 stands for mcol).
			if (va == mcol) va = -1;
			if (vb == mcol) vb = -1;
			if (va > vb) swap(va, vb);
			if (vb == -1) {
				// Light path uses only mcol: the heavy partner must be a path {mcol, c}.
				ans += 1LL * (dsum[group[i]] - d[group[i]][-1]) * j.second;
			}
			else if (va == -1) {
				// Light path is {mcol, vb} or only vb: the heavy partner uses
				// {mcol, vb} or only mcol.
				if (d[group[i]].find(vb) != d[group[i]].end()) {
					ans += 1LL * d[group[i]][vb] * j.second;
				}
				ans += 1LL * d[group[i]][-1] * j.second;
			}
			// Otherwise the light path has two colours other than mcol. Every heavy
			// path contains mcol, so such a pair always has three colours.
		}
		// Re-key the inherited map from mcol to pcol. When they differ, only vertices
		// whose path used mcol alone (key -1) or {mcol, pcol} (key pcol) remain within
		// two colours; they are stored under key mcol and every other entry is
		// removed from dsum.
		int pcol = par[i].col;
		if (mcol != pcol && mcol != -1) {
			int val = 0, valp = 0;
			for (pair<int, int> j : d[group[i]]) {
				if (j.first == -1) {
					val = j.second;
				}
				else if (j.first == pcol) {
					valp = j.second;
				}
				else {
					dsum[group[i]] -= j.second;
				}
			}
			d[group[i]].clear();
			d[group[i]][mcol] += val + valp;
		}
		// Merge the light children's maps into i's map, re-keyed relative to pcol.
		// Entries whose colours (with pcol removed) are two distinct colours are
		// discarded.
		for (int j = 1; j < child[i].size(); ++j) {
			int col = child[i][j].col;
			for (pair<int, int> k : d[group[child[i][j].to]]) {
				int va = col, vb = k.first;
				if (va == pcol) va = -1;
				if (vb == pcol) vb = -1;
				if (va > vb) swap(va, vb);
				if (va == vb) va = -1;
				if (va != -1) continue;
				dsum[group[i]] += k.second;
				d[group[i]][vb] += k.second;
			}
		}
		// Vertex i itself: an empty path, counted under key -1.
		dsum[group[i]] += 1;
		d[group[i]][-1] += 1;
	}
	return ans;
}
#include <set>
long long solve_easy(int N, vector<vector<edge> > G) {
	long long ans = 0;
	for (int i = 0; i < N; ++i) {
		for (int j = 0; j < i; ++j) {
			set<int> s;
			function<bool(int, int)> dfs = [&](int pos, int pre) {
				if (pos == i) return true;
				bool ok = false;
				for (edge e : G[pos]) {
					if (e.to == pre) continue;
					bool res = dfs(e.to, pos);
					if (res) {
						s.insert(e.col);
						ok = true;
					}
				}
				return ok;
			};
			dfs(j, -1);
			if (s.size() == 2) {
				++ans;
			}
		}
	}
	return ans;
}
#include <random>
#include <string>
// Shared pseudo-random engine with a fixed seed, so every stress-test run draws
// the same sequence.
mt19937 mt(2003012025);
// Returns a uniformly distributed integer in the half-open interval [l, r).
// Requires l < r, because uniform_int_distribution needs lower bound <= upper bound.
int rand_rng(int l, int r) {
	const int lo = l;
	const int hi = r - 1;
	return uniform_int_distribution<int>(lo, hi)(mt);
}
// Formats an integer vector as "[a, b, c]" ("[]" when empty); used for the
// stress-test diagnostics. Takes the vector by const reference so the caller's
// vector is not copied on every call.
string to_string(const vector<int>& arr) {
	string ans = "[";
	bool first = true;
	for (int x : arr) {
		if (!first) ans += ", ";
		ans += std::to_string(x);
		first = false;
	}
	ans += "]";
	return ans;
}
// Randomised stress test comparing solve() with the brute-force solve_easy().
// Each sample builds a random tree on N vertices: vertex j (j >= 1) is attached to
// a uniformly random earlier vertex in [0, j), with a colour drawn from [0, K).
// Every disagreement is printed together with the parent/colour arrays that
// reproduce it. The defaults reproduce the original fixed configuration
// (10000 samples, N = 7, K = 3), so existing `random_gen()` calls are unchanged.
void random_gen(int samples = 10000, int N = 7, int K = 3) {
	// solve() always starts its DFS at vertex 0, and rand_rng(0, K) needs a
	// non-empty range, so degenerate configurations are rejected up front.
	if (N < 1 || K < 1) return;
	for (int i = 1; i <= samples; ++i) {
		vector<vector<edge> > G(N);
		vector<int> parseq(N, -1), colseq(N, -1);
		for (int j = 1; j < N; ++j) {
			int par = rand_rng(0, j), col = rand_rng(0, K);
			G[j].push_back(edge(par, col));
			G[par].push_back(edge(j, col));
			parseq[j] = par;
			colseq[j] = col;
		}
		long long res1 = solve(N, G);
		long long res2 = solve_easy(N, G);
		if (res1 != res2) {
			cout << "Case #" << i << " / N = " << N << endl;
			cout << "par = " << to_string(parseq) << endl;
			cout << "col = " << to_string(colseq) << endl;
			cout << "Returns: " << res1 << endl;
			cout << "Answer: " << res2 << endl;
			cout << endl;
		}
	}
}
// Reads N and K, then N - 1 edges "a b col" (all 1-indexed), and prints the number
// of vertex pairs whose path uses exactly two colours, as computed by solve().
// To stress-test solve() against solve_easy() instead, call random_gen() here.
int main() {
	ios_base::sync_with_stdio(false);
	cin.tie(nullptr);
	int N, K;
	cin >> N >> K;  // K (number of colours) is read but not otherwise used
	vector<vector<edge> > G(N);
	for (int read = 1; read < N; ++read) {
		int u, v, c;
		cin >> u >> v >> c;
		// Convert vertices and colours to 0-indexed; store the edge in both directions.
		G[u - 1].push_back(edge(v - 1, c - 1));
		G[v - 1].push_back(edge(u - 1, c - 1));
	}
	cout << solve(N, G) << endl;
	return 0;
}
0