結果

問題 No.2801 Unique Maximum
ユーザー 👑 potato167
提出日時 2024-06-28 20:15:51
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 1,642 ms / 4,000 ms
コード長 4,383 bytes
コンパイル時間 4,061 ms
コンパイル使用メモリ 251,664 KB
実行使用メモリ 27,136 KB
最終ジャッジ日時 2024-06-28 20:26:15
合計ジャッジ時間 22,276 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 4 ms
5,376 KB
testcase_02 AC 422 ms
9,416 KB
testcase_03 AC 2 ms
5,376 KB
testcase_04 AC 4 ms
5,376 KB
testcase_05 AC 6 ms
5,376 KB
testcase_06 AC 7 ms
5,376 KB
testcase_07 AC 1,624 ms
27,136 KB
testcase_08 AC 1,588 ms
26,980 KB
testcase_09 AC 2 ms
5,376 KB
testcase_10 AC 1,558 ms
25,300 KB
testcase_11 AC 1,642 ms
24,484 KB
testcase_12 AC 1,564 ms
23,616 KB
testcase_13 AC 1,569 ms
24,612 KB
testcase_14 AC 1,598 ms
25,304 KB
testcase_15 AC 4 ms
5,376 KB
testcase_16 AC 3 ms
5,376 KB
testcase_17 AC 5 ms
5,376 KB
testcase_18 AC 615 ms
13,348 KB
testcase_19 AC 662 ms
14,400 KB
testcase_20 AC 721 ms
14,996 KB
testcase_21 AC 1,538 ms
24,316 KB
testcase_22 AC 1,321 ms
25,644 KB
testcase_23 AC 1,544 ms
24,404 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using ll = long long;
#define rep(i, a, b) for (ll i = (ll)(a); i < (ll)(b); i++)
#include <atcoder/convolution>
using mint = atcoder::modint998244353;
using namespace std;

// Reference O(N^2 * M) table used by test() to cross-check solve().
// Returns res with N rows and M columns where
//   res[0][j] = 1                                   for all j,
//   res[i][0] = 0                                   for i >= 1,
//   res[i][j] = res[i][j-1]
//             + sum_{k=0}^{i-1} res[k][j-1] * res[i-1-k][j-1]   otherwise.
// NOTE(review): presumably res[i][j] counts length-i sequences over {1..j}
// for the "Unique Maximum" problem (either every value is < j, or j occurs
// once at position k+1 and splits the rest into two independent parts) —
// confirm against the problem statement.
// For N <= 0 an empty table is returned (writing res[0] would otherwise be
// out of bounds).
vector<vector<mint>> solve_naive(ll N, ll M){
    if (N <= 0) return {};
    vector<vector<mint>> res(N, vector<mint>(M));
    rep(i, 0, M) res[0][i] = 1;
    // Row i only reads rows < i and column j-1, so i-outer/j-inner order is valid.
    rep(i, 1, N) rep(j, 1, M){
        res[i][j] = res[i][j - 1];
        rep(k, 0, i) res[i][j] += res[k][j - 1] * res[i - 1 - k][j - 1];
    }
    return res;
}
// Computes, for a single column M, the same values as solve_naive: entry t
// (0 <= t < N) of the returned vector equals res[t][M] of solve_naive's
// recurrence (test() cross-checks this against solve_naive).
//
// Derivation (from solve_naive's recurrence): F_j(x) = sum_i res[i][j] x^(i+1)
// satisfies F_0 = x and F_j = F_{j-1} + F_{j-1}^2 = g(F_{j-1}) with
// g(x) = x + x^2, so f := F_M is the M-fold composition of g and therefore
//     f(x + x^2) = f(x) + f(x)^2        (B == A below).
// f[k] is the coefficient of x^k: f[0] = 0, f[1] = 1, f[2] = M, and the
// returned vector is f[1 .. N] (entry t is f[t + 1]).
//
// Comparing coefficients of x^m (m >= 4): the f[m] terms cancel on both
// sides, f[m-1] contributes 2*f[1]*f[m-1] = 2*f[m-1] to A and
// C(m-1, 1)*f[m-1] = (m-1)*f[m-1] to B, so
//     f[m-1] = (A[m] - B[m]) / (m - 3)
// where A[m], B[m] hold only the contributions of f[0 .. m-2]. The
// coefficients are fixed left to right by a divide-and-conquer over index
// ranges (calc), doing O(1) convolutions of length O(r - l) per call,
// i.e. about O(N log^2 N) overall.
// NOTE(review): the division by (m - 3) assumes m - 3 is invertible modulo
// 998244353, i.e. N stays well below the modulus.
vector<mint> solve(ll N, ll M){
    // f += g * v coefficient-wise, appending zeros to f when g is longer.
    auto add = [&](vector<mint> &f, vector<mint> g, mint v = 1) -> void {
        rep(i, 0, g.size()){
            if (i == (int)f.size()){
                f.push_back(0);
            }
            f[i] += g[i] * v;
        }
    };
    // Factorials / inverse factorials up to N for the binomial C(a, b),
    // which is 0 outside 0 <= b <= a.
    vector<mint> fact(N + 1, 1), fact_inv(N + 1, 1);
    rep(i, 1, N + 1) fact[i] = fact[i - 1] * i;
    fact_inv[N] = fact[N].inv();
    for (int i = N; i > 0; i--){
        fact_inv[i - 1] = fact_inv[i] * i;
    }
    auto C = [&](int a, int b) -> mint {
        if (a < b || b < 0) return 0;
        return fact[a] * fact_inv[a - b] * fact_inv[b];
    };

    // f: coefficients being determined; f[N + 1] is never fixed and is dropped.
    // A accumulates the coefficients of f + f^2 and B those of f(x + x^2),
    // each built only from entries of f that have already been fixed.
    vector<mint> f(N + 2);
    vector<mint> A(N + 2); // f + f ^ 2
    vector<mint> B(N + 2); // f(x + x ^ 2)
    // g(x) = x + x ^ 2
    vector<mint> g = {0, 1, 1};
    // calc(self, l, r) fixes f[l .. r-2] (f[m-1] at each split point m) and returns
    //   first  = g(x)^(r - l - 1),
    //   second = sum_{i = l}^{r - 2} g(x)^(i - l) * f[i].
    // Invariant on entry: for every i in [l, r), A[i] already includes all
    // products f[a]*f[b] with a, b < l, and B[i] all terms coming from
    // f[0 .. l-1]. Each call adds the contributions of its left part
    // f[l .. m-1] to the targets [m, r) before recursing into [m, r).
    auto calc = [&](auto self, int l, int r) -> pair<vector<mint>, vector<mint>> {
        // Empty range: g^0 = 1 and an empty sum.
        if (l + 1 == r) return {{1}, {0}};
        int m = (l + r) / 2;
        pair<vector<mint>, vector<mint>> tmp1 = self(self, l, m);
        // A[m .. r-1] += f[a]*f[b] for a in [l, m-2] and b < l (b < r - l suffices
        // since a + b < r); such pairs occur as (a, b) and (b, a), hence * 2.
        vector<mint> p(m - l);
        vector<mint> q(min(l, r - l));
        rep(i, 0, min(l, r - l)) q[i] = f[i];
        rep(i, l, m - 1) p[i - l] = f[i];
        auto v = atcoder::convolution(p, q);
        rep(i, m, r){
            int j = i - l;
            if (0 <= j && j < (int)v.size()) A[i] += v[j] * 2;
        }
        // A[m .. r-1] += f[a]*f[b] for a, b both in [l, m-2] (index offset 2l).
        v = atcoder::convolution(p, p);
        rep(i, m, r){
            int j = i - l * 2;
            if (0 <= j && j < (int)v.size()) A[i] += v[j];
        }
        // B[m .. r-1] += sum_{i=l}^{m-2} f[i] * g^i = g^l * tmp1.second
        // (tmp1.second does not yet include f[m-1]).
        // h holds coefficients [a, r) of g^l = x^l (1 + x)^l, i.e. h[i - a] = C(l, i - l);
        // lower coefficients cannot reach targets >= m because
        // deg(tmp1.second) <= |tmp1.second| - 1.
        int a = max(0, m - (int)(tmp1.second.size()) + 1);
        int b = r;
        int len = b - a;
        vector<mint> h(len);
        rep(i, a, b) h[i - a] = C(l, i - l);
        p = atcoder::convolution(h, tmp1.second);
        rep(i, m, r){
            int j = i - a;
            if (0 <= j && j < (int)p.size()){
                B[i] += p[j];
            }
        }
        // Fix f[m - 1]: base values f[0] = 0, f[1] = 1, f[2] = M for m <= 3
        // (where m - 3 would be zero), otherwise the x^m coefficient identity.
        if (m <= 3){
            if (m == 1){
                f[m - 1] = 0;
            }
            if (m == 2){
                f[m - 1] = 1;
            }
            if (m == 3){
                f[m - 1] = M;
            }
        }
        else{
            f[m - 1] = (A[m] - B[m]) / (m - 3);
        }
        // Contributions of f[m-1] to A: the linear term at index m-1, then
        // 2*f[m-1]*f[c] for c < m-1 and f[m-1]^2 at targets in [m, r).
        A[m - 1] += f[m - 1];
        rep(i, m, r){
            if (i < 2 * (m - 1)) A[i] += f[m - 1] * f[i - m + 1] * 2ll;
            if (i == 2 * (m - 1)) A[i] += f[m - 1] * f[m - 1];
        }
        // Contribution of f[m-1] * g^(m-1) to B: coefficient i of g^(m-1) is
        // C(m-1, i-m+1), for targets in [m-1, r).
        rep(i, m - 1, r){
            B[i] += f[m - 1] * C(m - 1, i - m + 1);
        }
        // Extend the left sum with f[m-1] * g^(m-1-l) (tmp1.first == g^(m-l-1)).
        add(tmp1.second, tmp1.first, f[m - 1]);
        pair<vector<mint>, vector<mint>> tmp2 = self(self, m, r);
        // Merge: right sum is shifted by g^(m-l), and
        // g^(m-l) * g^(r-m-1) = g^(r-l-1).
        tmp1.first = atcoder::convolution(tmp1.first, g);
        add(tmp1.second, atcoder::convolution(tmp2.second, tmp1.first));
        tmp1.first = atcoder::convolution(tmp1.first, tmp2.first);
        return tmp1;
    };
    // Fix f[0 .. N], then drop f[0] (always 0) and the never-fixed f[N + 1].
    calc(calc, 0, N + 2);
    f.erase(f.begin());
    f.pop_back();
    return f;
}

void test(){
    ll N = 500, M = 500;
    auto ans1 = solve_naive(N, M);
    rep(j, 1, M){
        auto ans2 = solve(N, j);
        rep(i, 0, N){
            if (ans2[i] != ans1[i][j]){
                cout << "No" << endl;
                cout << ans1[i][j].val() << " " << ans2[i].val() << "\n";
                cout << i << " " << j << "\n";
                assert(false);
            }
        }
    }
    cout << "Yes\n";
}

int main(){
    // test();
    ll N, M;
    cin >> N >> M;
    cout << solve(N + 1, M).back().val() << "\n";
}
0