結果

問題 No.3443 Sum of (Tree Distances)^K 1
コンテスト
ユーザー 👑 potato167
提出日時 2025-12-26 01:53:25
言語 C++17
(gcc 15.2.0 + boost 1.89.0)
結果
TLE  
実行時間 -
コード長 3,159 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 6,730 ms
コンパイル使用メモリ 217,916 KB
実行使用メモリ 17,316 KB
最終ジャッジ日時 2026-02-06 20:50:17
合計ジャッジ時間 9,824 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 26 TLE * 1 -- * 20
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#include <bits/stdc++.h>
using namespace std;

static const long long MOD = 998244353LL;

long long mod_pow(long long a, long long e) {
    a %= MOD;
    if (a < 0) a += MOD;
    long long r = 1;
    while (e > 0) {
        if (e & 1) r = (r * a) % MOD;
        a = (a * a) % MOD;
        e >>= 1;
    }
    return r;
}
long long mod_inv(long long a) { return mod_pow(a, MOD - 2); }

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N, K;
    cin >> N >> K;

    const long long inv2 = (MOD + 1) / 2;

    // factorials for combinations
    vector<long long> fact(N + 1), invfact(N + 1);
    fact[0] = 1;
    for (int i = 1; i <= N; i++) fact[i] = fact[i - 1] * i % MOD;
    invfact[N] = mod_inv(fact[N]);
    for (int i = N; i >= 1; i--) invfact[i - 1] = invfact[i] * i % MOD;

    auto C = [&](int n, int r) -> long long {
        if (r < 0 || r > n) return 0;
        return fact[n] * invfact[r] % MOD * invfact[n - r] % MOD;
    };

    // t^K for t=0..N (we use t>=1)
    vector<long long> tPowK(N + 1, 0);
    for (int t = 1; t <= N; t++) tPowK[t] = mod_pow(t, K);

    // W[s] = sum over unordered pairs (u<v) in s-vertex Cayley trees whose path contains root, dist^K
    // W_s = sum_{t=1}^{s-1} A_{s,t} * t^K
    // A_{s,t} = ((s-1)!/(s-t-1)!) * (t+1)^2 * s^{s-t-2} / 2
    vector<long long> W(N + 1, 0);
    W[1] = 0;
    for (int s = 2; s <= N; s++) {
        long long inv_s = mod_inv(s);
        long long fs1 = fact[s - 1];
        for (int t = 1; t <= s - 1; t++) {
            long long fall = fs1 * invfact[s - t - 1] % MOD; // (s-1)!/(s-t-1)!
            long long tp1 = (t + 1) % MOD;
            long long tp1sq = tp1 * tp1 % MOD;

            int exp = s - t - 2; // can be -1 only when t=s-1
            long long pow_s = (exp == -1 ? inv_s : mod_pow(s, exp));

            long long A = fall * tp1sq % MOD * pow_s % MOD * inv2 % MOD;
            W[s] = (W[s] + A * tPowK[t]) % MOD;
        }
    }

    // N^e
    vector<long long> powN(N + 1, 1);
    for (int e = 1; e <= N; e++) powN[e] = powN[e - 1] * N % MOD;

    // Answer for each a
    // if a==N: Ans[N] = W[N]
    // else (m=N-a>=1):
    // Ans[a] = N^{N-a-1} * sum_{s=1..a} C(a-1,s-1) * (s*W[s]) * F(a-s,m)
    // where F(0,m)=1, and for n>=1: F(n,m)= m * (m+n)^{n-1} = m*(N-s)^{n-1}
    for (int a = 1; a <= N; a++) {
        long long ans = 0;
        if (a == N) {
            ans = W[N];
        } else {
            int m = N - a;                // >=1
            long long outer = powN[N - a - 1];
            long long sum = 0;

            for (int s = 1; s <= a; s++) {
                long long comb = C(a - 1, s - 1);
                long long term = comb * (long long)s % MOD * W[s] % MOD;

                int n = a - s;
                long long F = 1;
                if (n >= 1) {
                    // m+n = N-s
                    F = (long long)m % MOD * mod_pow(N - s, n - 1) % MOD;
                }
                term = term * F % MOD;
                sum += term;
                if (sum >= MOD) sum -= MOD;
            }
            ans = outer * sum % MOD;
        }
        cout << ans << "\n";
    }
    return 0;
}
0