結果
| 問題 |
No.3346 Tree to DAG
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2025-11-22 18:05:36 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
AC
|
| 実行時間 | 136 ms / 2,000 ms |
| コード長 | 3,528 bytes |
| コンパイル時間 | 2,265 ms |
| コンパイル使用メモリ | 207,360 KB |
| 実行使用メモリ | 35,184 KB |
| 最終ジャッジ日時 | 2025-11-22 18:05:43 |
| 合計ジャッジ時間 | 6,756 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 39 |
コンパイルメッセージ
main.cpp: In function ‘int main()’:
main.cpp:73:14: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
73 | scanf("%d%d", &a, &b);
| ~~~~~^~~~~~~~~~~~~~~~
ソースコード
#include <bits/stdc++.h>
using namespace std;
typedef pair<int, int> pii;
typedef long long ll;
const int N = 2000010, MOD = 998244353, INF = 0x3f3f3f3f;
int n, m, w[N];
int e[N], ne[N], h[N], idx, d[N], c;
int pa, pb, pc, pd, px, py, pz, pzz;
multiset<int> son[100010];
void add(int a, int b) { e[idx] = b, ne[idx] = h[a], h[a] = idx++; }
ll qmi(ll a, ll b, ll c) { ll res = 1; while (b) { if (b & 1) res = res * a % c; a = a * a % c; b >>= 1; } return res; }
void dfs(int r, int fa) {
for (int i = h[r]; ~i; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
d[j] = d[r] + 1;
if (d[j] > d[c]) c = j;
dfs(j, r);
}
}
void dfs2(int r, int fa) {
for (int i = h[r]; ~i; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
dfs2(j, r);
son[r].insert(*son[j].rbegin() + 1);
}
if (!son[r].size()) son[r].insert(0);
}
void dfs3(int r, int fa) {
for (int i = h[r]; ~i; i = ne[i]) {
int j = e[i];
if (j == fa) continue;
if (*son[r].rbegin() == *son[j].rbegin() + 1) {
if (son[r].size() > 1) {
son[j].insert(*(++son[r].rbegin()) + 1);
} else son[j].insert(1);
} else son[j].insert(*son[r].rbegin() + 1);
dfs3(j, r);
}
}
void add(vector<int>& v, int p) {
p = 100010 - p;
while (1) {
v[p]++;
if (v[p] == 2) v[p] = 0, p--;
else break;
}
}
void sub(vector<int>& v, int p) {
p = 100010 - p;
while (1) {
if (v[p] == 1) {
v[p]--;
break;
} else {
v[p--]++;
}
}
}
int main() {
cin >> n;
memset(h, -1, sizeof(int) * (n + 10));
for (int i = 1, a, b; i < n; i++) {
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
dfs(1, 0);
d[c] = 0;
dfs(c, 0);
pa = n - 1 - (d[c] >> 1) + 2, pb = n - 1 - ((d[c] + 1) >> 1) + 2, pc = n - 1 - d[c] + 2, pd = n - 1 - d[c] + 1;
dfs2(c, 0);
dfs3(c, 0);
for (int i = 1; i < n + 1; i++)
if (son[i].size() >= 3) {
auto u = son[i].rbegin();
int x = *(u++), y = *(u++), z = *(u++);
vector<int> v = {x + y, y + z, z + x};
sort(v.begin(), v.end());
if (v[0] >= px) {
if (v[0] > px) px = v[0], py = v[1], pz = v[2], pzz = x + y + z;
else {
if (v[1] >= py) {
if (v[1] > py) px = v[0], py = v[1], pz = v[2], pzz = x + y + z;
else if (v[2] > pz) px = v[0], py = v[1], pz = v[2], pzz = x + y + z;
}
}
}
}
// printf("%d %d %d %d\n", px, py, pz, pzz);
px = n - 1 - px + 2, py = n - 1 - py + 2, pz = n - 1 - pz + 2, pzz = n - 1 - pzz + 1;
assert(px && py && pz && pzz);
assert(pa && pb && pc && pd);
vector<int> a(100010), b(100010);
add(a, pa), add(a, pb), add(a, pc), sub(a, pd), sub(a, pd), sub(a, pd);
add(b, px), add(b, py), add(b, pz), sub(b, pzz), sub(b, pzz), sub(b, pzz);
if (a < b) {
ll res = qmi(2, n + 2, MOD);
res = (res - qmi(2, pa, MOD) - qmi(2, pb, MOD) - qmi(2, pc, MOD) + \
qmi(2, pd, MOD) * 3ll + MOD * 3ll) % MOD;
printf("%lld\n", res);
} else {
ll res = qmi(2, n + 2, MOD);
res = (res - qmi(2, px, MOD) - qmi(2, py, MOD) - qmi(2, pz, MOD) + \
qmi(2, pzz, MOD) * 3ll + MOD * 3ll) % MOD;
printf("%lld\n", res);
}
return 0;
}