結果

問題 No.3346 Tree to DAG
コンテスト
ユーザー The Forsaking
提出日時 2025-11-22 18:05:36
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 136 ms / 2,000 ms
コード長 3,528 bytes
コンパイル時間 2,265 ms
コンパイル使用メモリ 207,360 KB
実行使用メモリ 35,184 KB
最終ジャッジ日時 2025-11-22 18:05:43
合計ジャッジ時間 6,756 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 39
権限があれば一括ダウンロードができます
コンパイルメッセージ
main.cpp: In function ‘int main()’:
main.cpp:73:14: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
   73 |         scanf("%d%d", &a, &b);
      |         ~~~~~^~~~~~~~~~~~~~~~

ソースコード

diff #
raw source code

#include <bits/stdc++.h>

using namespace std;

typedef pair<int, int> pii;
typedef long long ll;
const int N = 2000010, MOD = 998244353, INF = 0x3f3f3f3f;
int n, m, w[N];

int e[N], ne[N], h[N], idx, d[N], c;
int pa, pb, pc, pd, px, py, pz, pzz;
multiset<int> son[100010];
void add(int a, int b) { e[idx] = b, ne[idx] = h[a], h[a] = idx++; }
ll qmi(ll a, ll b, ll c) { ll res = 1; while (b) { if (b & 1) res = res * a % c; a = a * a % c; b >>= 1; } return res; }

void dfs(int r, int fa) {
    for (int i = h[r]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa) continue;
        d[j] = d[r] + 1;
        if (d[j] > d[c]) c = j;
        dfs(j, r);
    }
}

void dfs2(int r, int fa) {
    for (int i = h[r]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa) continue;
        dfs2(j, r);
        son[r].insert(*son[j].rbegin() + 1);
    }
    if (!son[r].size()) son[r].insert(0);
}

void dfs3(int r, int fa) {
    for (int i = h[r]; ~i; i = ne[i]) {
        int j = e[i];
        if (j == fa) continue;
        if (*son[r].rbegin() == *son[j].rbegin() + 1) {
            if (son[r].size() > 1) {
                son[j].insert(*(++son[r].rbegin()) + 1);
            } else son[j].insert(1);
        } else son[j].insert(*son[r].rbegin() + 1);
        dfs3(j, r);
    }
}

void add(vector<int>& v, int p) {
    p = 100010 - p;
    while (1) {
        v[p]++;
        if (v[p] == 2) v[p] = 0, p--;
        else break;
    }
}
void sub(vector<int>& v, int p) {
    p = 100010 - p;
    while (1) {
        if (v[p] == 1) {
            v[p]--;
            break;
        } else {
            v[p--]++;
        }
    }
}

int main() {
    cin >> n;
    memset(h, -1, sizeof(int) * (n + 10));
    for (int i = 1, a, b; i < n; i++) {
        scanf("%d%d", &a, &b);
        add(a, b), add(b, a);
    }

    dfs(1, 0);
    d[c] = 0;
    dfs(c, 0);
    pa = n - 1 - (d[c] >> 1) + 2, pb = n - 1 - ((d[c] + 1) >> 1) + 2, pc = n - 1 - d[c] + 2, pd = n - 1 - d[c] + 1;

    dfs2(c, 0);
    dfs3(c, 0);
    for (int i = 1; i < n + 1; i++)
        if (son[i].size() >= 3) {
            auto u = son[i].rbegin();
            int x = *(u++), y = *(u++), z = *(u++);
            vector<int> v = {x + y, y + z, z + x};
            sort(v.begin(), v.end());
            if (v[0] >= px) {
                if (v[0] > px) px = v[0], py = v[1], pz = v[2], pzz = x + y + z;
                else {
                    if (v[1] >= py) {
                        if (v[1] > py) px = v[0], py = v[1], pz = v[2], pzz = x + y + z;
                        else if (v[2] > pz) px = v[0], py = v[1], pz = v[2], pzz = x + y + z;
                    }
                }
            }
        }
    // printf("%d %d %d %d\n", px, py, pz, pzz);
    px = n - 1 - px + 2, py = n - 1 - py + 2, pz = n - 1 - pz + 2, pzz = n - 1 - pzz + 1;
    assert(px && py && pz && pzz);
    assert(pa && pb && pc && pd);

    vector<int> a(100010), b(100010);
    add(a, pa), add(a, pb), add(a, pc), sub(a, pd), sub(a, pd), sub(a, pd);
    add(b, px), add(b, py), add(b, pz), sub(b, pzz), sub(b, pzz), sub(b, pzz);
    if (a < b) {
        ll res = qmi(2, n + 2, MOD);
        res = (res - qmi(2, pa, MOD) - qmi(2, pb, MOD) - qmi(2, pc, MOD) + \
            qmi(2, pd, MOD) * 3ll + MOD * 3ll) % MOD;
        printf("%lld\n", res);
    } else {
        ll res = qmi(2, n + 2, MOD);
        res = (res - qmi(2, px, MOD) - qmi(2, py, MOD) - qmi(2, pz, MOD) + \
            qmi(2, pzz, MOD) * 3ll + MOD * 3ll) % MOD;
        printf("%lld\n", res);
    }
    return 0;
}
0