結果

問題 No.3451 Same Numbers
コンテスト
ユーザー 👑 potato167
提出日時 2026-01-27 02:40:56
言語 C++17
(gcc 15.2.0 + boost 1.89.0)
結果
WA  
(最新)
AC  
(最初)
実行時間 -
コード長 4,377 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 1,915 ms
コンパイル使用メモリ 218,400 KB
実行使用メモリ 7,976 KB
最終ジャッジ日時 2026-02-20 20:52:25
合計ジャッジ時間 8,250 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample WA * 3
other AC * 5 WA * 32
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#include <bits/stdc++.h>
using namespace std;

static const long long MOD = 998244353;

long long mod_pow(long long a, long long e) {
    long long r = 1 % MOD;
    a %= MOD;
    while (e > 0) {
        if (e & 1) r = (r * a) % MOD;
        a = (a * a) % MOD;
        e >>= 1;
    }
    return r;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N, M;
    cin >> N >> M;

    // inv[1..N] を O(N) で前計算
    vector<long long> inv(N + 1, 0);
    inv[1] = 1;
    for (int i = 2; i <= N; i++) {
        inv[i] = MOD - (MOD / i) * inv[MOD % i] % MOD;
    }

    // 二項係数用(n は最大 M 程度)
    vector<long long> fact(M + 1), invfact(M + 1);
    fact[0] = 1;
    for (int i = 1; i <= M; i++) fact[i] = fact[i - 1] * i % MOD;
    invfact[M] = mod_pow(fact[M], MOD - 2);
    for (int i = M; i >= 1; i--) invfact[i - 1] = invfact[i] * i % MOD;

    auto C = [&](int n, int r) -> long long {
        if (r < 0 || r > n) return 0;
        return fact[n] * invfact[r] % MOD * invfact[n - r] % MOD;
    };

    // 閾値(経験的にこの程度で十分)
    int S = min(N, 450);

    vector<long long> D(M + 1, 0);

    for (int k = 1; k <= N; k++) {
        if (k > M) {
            // 1回目に引いた S は M 回以内に復帰しない
            cout << 1 << "\n";
            continue;
        }

        int B = N - k + 1; // 定常時の箱サイズ
        if (B == 1) {
            // 箱に1個しかない:復帰したら必ず引かれるので、周期 k で確定
            long long ans = 1 + (M - 1) / k;
            ans %= MOD;
            cout << ans << "\n";
            continue;
        }

        if (k <= S) {
            // 小さい k:時間 DP(窓和)
            long long invB = inv[B];

            D[1] = 1;
            for (int t = 2; t <= k; t++) D[t] = 0;

            long long ans = 1;
            if (k == 1) {
                // 以降は毎手独立に 1/N
                ans = (1 + (long long)(M - 1) * invB) % MOD;
                cout << ans << "\n";
                continue;
            }

            long long W = 0; // W = sum_{j=1}^{k-1} D[t-j] を逐次管理
            for (int t = k + 1; t <= M; t++) {
                long long dt = (1 - W) % MOD;
                if (dt < 0) dt += MOD;
                dt = dt * invB % MOD;

                D[t] = dt;
                ans += dt;
                ans %= MOD;

                // W_{t+1} = W_t + D[t] - D[t-k+1]
                W += dt;
                if (W >= MOD) W -= MOD;
                W -= D[t - k + 1];
                if (W < 0) W += MOD;
            }
            cout << ans % MOD << "\n";
        } else {
            // 大きい k:二項分布の裾確率(m は小さい)
            long long p = inv[B];
            long long q = (long long)(B - 1) * inv[B] % MOD; // 1 - p
            long long r = inv[B - 1]; // p/q = 1/(B-1)

            int mmax = (M - 1) / k;
            long long ans = 1;

            // m=1 のとき n = M-1 - (k-1) = M-k, よって q^n
            int n1 = M - k;
            long long qpow = mod_pow(q, n1);

            // 次の m では n が (k-1) 減るので、q^{-(k-1)} を掛けて更新
            // inv(q) = B/(B-1)
            long long invq = (long long)B * inv[B - 1] % MOD;
            long long step = mod_pow(invq, k - 1); // q^{-(k-1)}

            for (int m = 1; m <= mmax; m++) {
                int n = M - 1 - m * (k - 1);
                if (n < m) break; // Bin(n,p) >= m は不可能

                // sum_{j=0}^{m-1} C(n,j) p^j q^{n-j}
                // = q^n * sum_{j=0}^{m-1} C(n,j) (p/q)^j
                long long sum = 0;
                long long rpow = 1;
                for (int j = 0; j <= m - 1; j++) {
                    long long term = C(n, j) * rpow % MOD;
                    sum += term;
                    if (sum >= MOD) sum -= MOD;
                    rpow = rpow * r % MOD;
                }
                sum = sum * qpow % MOD;

                long long prob = (1 - sum) % MOD;
                if (prob < 0) prob += MOD;

                ans += prob;
                if (ans >= MOD) ans -= MOD;

                // 次の m 用に q^n を更新(n が (k-1) 減る)
                qpow = qpow * step % MOD;
            }

            cout << ans % MOD << "\n";
        }
    }

    return 0;
}
0