結果

問題 No.2459 Stampaholic (Hard)
ユーザー srjywrdnprkt
提出日時 2023-11-30 13:36:09
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 393 ms / 4,000 ms
コード長 6,483 bytes
コンパイル時間 3,737 ms
コンパイル使用メモリ 249,340 KB
実行使用メモリ 35,040 KB
最終ジャッジ日時 2024-09-26 14:03:13
合計ジャッジ時間 9,299 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 13 ms
11,032 KB
testcase_01 AC 393 ms
34,896 KB
testcase_02 AC 99 ms
15,876 KB
testcase_03 AC 15 ms
11,136 KB
testcase_04 AC 14 ms
11,136 KB
testcase_05 AC 14 ms
11,008 KB
testcase_06 AC 14 ms
11,008 KB
testcase_07 AC 15 ms
11,188 KB
testcase_08 AC 192 ms
20,436 KB
testcase_09 AC 99 ms
16,480 KB
testcase_10 AC 385 ms
32,040 KB
testcase_11 AC 196 ms
22,520 KB
testcase_12 AC 375 ms
34,356 KB
testcase_13 AC 374 ms
31,928 KB
testcase_14 AC 98 ms
17,196 KB
testcase_15 AC 379 ms
35,040 KB
testcase_16 AC 381 ms
34,892 KB
testcase_17 AC 380 ms
34,920 KB
testcase_18 AC 378 ms
35,016 KB
testcase_19 AC 384 ms
34,896 KB
testcase_20 AC 14 ms
11,232 KB
testcase_21 AC 370 ms
29,940 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#include <atcoder/modint>
#include <atcoder/convolution>

using namespace std;
using namespace atcoder;
using ll = long long;

// All arithmetic in this file is done in this modular integer type
// (AC Library modint with modulus 998244353).
using mint = modint998244353;

// Largest index held by the factorial tables; init() sizes them to MAX+1.
const ll MAX = 1e6+10;
// After init(): f[i] = i! and finv[i] = 1/(i!) for 0 <= i <= MAX.
// Both are empty until init() has been called.
vector<mint> f, finv;

// Modular inverse of x via Fermat's little theorem: x^(p-2) with p = mint::mod().
//
// The exponent is derived from the modulus instead of being hard-coded, and
// the square-and-multiply loop is delegated to mint::pow.
// x == 0 yields 0 (0^(p-2) = 0); pow is used rather than mint::inv() so that
// a zero argument does not trip an assertion.
mint inv(mint x){
    return x.pow(mint::mod() - 2);
}

// Fills the global tables f (factorials) and finv (inverse factorials)
// for every index 0..MAX. Must run before C() or any direct table access.
void init(){
    const int top = MAX;
    f.resize(top + 1);
    finv.resize(top + 1);

    // Forward pass: (k+1)! = k! * (k+1).
    f[0] = 1;
    for (int k = 0; k < top; k++) f[k+1] = f[k] * (k+1);

    // One modular inversion for the largest factorial, then walk down:
    // 1/(k-1)! = 1/k! * k.
    finv[top] = inv(f[top]);
    for (int k = top; k > 0; k--) finv[k-1] = finv[k] * k;
}

// Binomial coefficient "n choose k" in mint arithmetic, read from the
// precomputed tables f / finv.
//
// Returns 0 when k < 0 or k > n (this also covers every negative n).
// Otherwise n must be a valid table index: init() must have been called and
// n must not exceed the table bound. This is asserted so that an
// out-of-range n aborts instead of silently reading past the tables.
mint C(ll n, ll k){
    if (k < 0 || n < k) return 0;

    // From here 0 <= k <= n, so bounding n bounds all three lookups.
    assert(n < static_cast<ll>(f.size()) && n < static_cast<ll>(finv.size()));

    return f[n] * finv[k] * finv[n-k];
}

// Truncated formal power series with coefficients of type T, stored as a
// vector where index i holds the coefficient of z^i.
//
// Every in-place operator keeps the length of the left-hand operand: results
// are truncated (or zero-padded) back to the size the series had on entry.
// T is expected to be a modular-integer type providing inv(); the dense
// product/inverse additionally rely on AC Library's convolution and
// internal::butterfly routines.
template<class T>
struct FormalPowerSeries : vector<T> {
    using vector<T>::vector;
    using vector<T>::operator=;
    using F = FormalPowerSeries;

    // Coefficient-wise negation; returns a new series.
    F operator-() const {
        F res(*this);
        for (auto &e : res) e = -e;
        return res;
    }
    // Multiplies every coefficient by the scalar g.
    F &operator*=(const T &g) {
        for (auto &e : *this) e *= g;
        return *this;
    }
    // Divides every coefficient by the scalar g; g must be non-zero (asserted).
    F &operator/=(const T &g) {
        assert(g != T(0));
        *this *= g.inv();
        return *this;
    }
    // Adds g coefficient-wise over the common prefix only: coefficients of g
    // beyond this series' length are ignored and the length is unchanged.
    F &operator+=(const F &g) {
        int n = (*this).size(), m = g.size();
        for (int i=0; i<min(n, m); i++) (*this)[i] += g[i];
        return *this;
    }
    // Subtracts g coefficient-wise over the common prefix only (see operator+=).
    F &operator-=(const F &g) {
        int n = (*this).size(), m = g.size();
        for (int i=0; i<min(n, m); i++) (*this)[i] -= g[i];
        return *this;
    }
    // Multiplies by z^d: prepends d zeros, then truncates to the original length.
    F &operator<<=(const int d) {
        int n = (*this).size();
        (*this).insert((*this).begin(), d, 0);
        (*this).resize(n);
        return *this;
    }
    // Drops the lowest min(n, d) coefficients (division by z^d discarding the
    // remainder), then zero-pads back to the original length.
    F &operator>>=(const int d) {
        int n = (*this).size();
        (*this).erase((*this).begin(), (*this).begin() + min(n, d));
        (*this).resize(n);
        return *this;
    }
    // Returns the first d coefficients of the multiplicative inverse 1/(*this).
    // d defaults to the current length. Requires a non-empty series with a
    // non-zero constant term and d > 0 (all asserted).
    //
    // Starts from the inverse of the constant term and doubles the number of
    // known coefficients each round, using transforms of length 2*m.
    F inv(int d = -1) const {
        int n = (*this).size();
        assert(n != 0 && (*this)[0] != 0);
        if (d == -1) d = n;
        assert(d > 0);
        F res{(*this)[0].inv()};
        while (res.size() < d) {
            // m coefficients of the inverse are known; this round computes the next m.
            int m = size(res);
            // f: the first 2*m coefficients of *this (zero-padded); r: current inverse.
            F f(begin(*this), begin(*this) + min(n, 2*m));
            F r(res);
            f.resize(2*m), internal::butterfly(f);
            r.resize(2*m), internal::butterfly(r);
            // Pointwise product in the transformed domain, then transform back.
            for (int i=0; i<2*m; i++) f[i] *= r[i];
            internal::butterfly_inv(f);
            // Keep only entries m..2*m-1 of that product, shifted down to 0..m-1,
            // and zero-pad to length 2*m again.
            f.erase(f.begin(), f.begin() + m);
            f.resize(2*m), internal::butterfly(f);
            // Multiply by the current inverse a second time.
            for (int i=0; i<2*m; i++) f[i] *= r[i];
            internal::butterfly_inv(f);
            // iz = -(1/(2m))^2: the minus sign negates the correction term, and
            // the squared 1/(2m) presumably normalises the two butterfly_inv
            // calls above (assumes butterfly_inv does not scale by 1/length --
            // confirm against the AC Library source).
            T iz = T(2*m).inv(); iz *= -iz;
            for (int i=0; i<m; i++) f[i] *= iz;
            // Append the m newly determined coefficients.
            res.insert(res.end(), f.begin(), f.begin() + m);
        }
        // res may hold more than d coefficients after the last doubling.
        return {res.begin(), res.begin() + d};
    }

    // fast: NTT-friendly modulus only
    // Dense product via convolution(), truncated to this series' length.
    F &operator*=(const F &g) {
        int n = (*this).size();
        *this = convolution(*this, g);
        (*this).resize(n);
        return *this;
    }
    // Dense quotient: multiplies by the first n coefficients of 1/g and
    // truncates to this series' length. g must satisfy inv()'s requirements.
    F &operator/=(const F &g) {
        int n = (*this).size();
        *this = convolution(*this, g.inv(n));
        (*this).resize(n);
        return *this;
    }

    // sparse
    // In-place product with a sparse series given as (degree, coefficient)
    // pairs. g must be non-empty (front() is read unconditionally); the early
    // `break` on j > i assumes the pairs are sorted by increasing degree.
    // g is taken by value because the degree-0 term is removed from it.
    F &operator*=(vector<pair<int, T>> g) {
        int n = (*this).size();
        // c becomes the constant term of g (0 if g has no degree-0 entry).
        auto [d, c] = g.front();
        if (d == 0) g.erase(g.begin());
        else c = 0;
        // Highest index first, so (*this)[i-j] still holds its old value when read.
        for (int i=n-1; i>=0; i--){
            (*this)[i] *= c;
            for (auto &[j, b] : g) {
                if (j > i) break;
                (*this)[i] += (*this)[i-j] * b;
            }
        }
        return *this;
    }
    // In-place division by a sparse series given as (degree, coefficient)
    // pairs. The first pair must be the degree-0 term with a non-zero
    // coefficient (asserted); the early `break` assumes increasing degrees.
    F &operator/=(vector<pair<int, T>> g) {
        int n = (*this).size();
        auto [d, c] = g.front();
        assert(d == 0 && c != T(0));
        T ic = c.inv();
        g.erase(g.begin());
        // Lowest index first, so (*this)[i-j] already holds the quotient coefficient.
        for (int i=0; i<n; i++){
            for (auto &[j, b] : g) {
            if (j > i) break;
            (*this)[i] -= (*this)[i-j] * b;
            }
            (*this)[i] *= ic;
        }
        return *this;
    }

    // multiply and divide (1 + cz^d)
    // In-place product with (1 + c z^d), truncated to the current length.
    // c == 1 and c == -1 are special-cased to avoid the multiplication.
    void multiply(const int d, const T c) {
        int n = (*this).size();
        if (c == T(1)) for (int i=n-d-1; i>=0; i--) (*this)[i+d] += (*this)[i];
        else if (c == T(-1)) for (int i=n-d-1; i>=0; i--) (*this)[i+d] -= (*this)[i];
        else for (int i=n-d-1; i>=0; i--) (*this)[i+d] += (*this)[i] * c;
    }
    // In-place division by (1 + c z^d), truncated to the current length.
    void divide(const int d, const T c) {
        int n = (*this).size();
        if (c == T(1)) for (int i=0; i<n-d; i++) (*this)[i+d] -= (*this)[i];
        else if (c == T(-1)) for (int i=0; i<n-d; i++) (*this)[i+d] += (*this)[i];
        else for (int i=0; i<n-d; i++) (*this)[i+d] -= (*this)[i] * c;
    }

    // Evaluates the stored polynomial at a: sum over i of coeff[i] * a^i.
    T eval(const T &a) const {
        T x(1), res(0);
        for (auto e : *this) res += e * x, x *= a;
        return res;
    }

    // Non-mutating counterparts: each copies *this and applies the matching
    // compound-assignment operator to the copy.
    F operator*(const T &g) const { return F(*this) *= g; }
    F operator/(const T &g) const { return F(*this) /= g; }
    F operator+(const F &g) const { return F(*this) += g; }
    F operator-(const F &g) const { return F(*this) -= g; }
    F operator<<(const int d) const { return F(*this) <<= d; }
    F operator>>(const int d) const { return F(*this) >>= d; }
    F operator*(const F &g) const { return F(*this) *= g; }
    F operator/(const F &g) const { return F(*this) /= g; }
    F operator*(vector<pair<int, T>> g) const { return F(*this) *= g; }
    F operator/(vector<pair<int, T>> g) const { return F(*this) /= g; }
};

// Repeats the alias already declared near the top of the file; it names the
// same type, so the redeclaration is harmless.
using mint = modint998244353;
// Formal power series with mint coefficients.
using fps = FormalPowerSeries<mint>;

int main(){

    init();

    ll H, W, N, K;
    cin >> H >> W >> N >> K;

    auto func=[&](ll X)->vector<mint>{
        ll x, y;
        mint c;
        if (K*2 <= X) x = K-1, y = X-K*2+2;
        else x = X-K, y = X-(X-K)*2;
        vector<mint> ps(N+1), s(N+1);
        fps a(N+1), b(N+1);
        c = 1;
        for (int i=0; i<=N; i++){
            b[i] = finv[i+1];
            c *= x+1;
            a[i] = c * finv[i+1];
        }
        a *= b.inv();
        for (int i=0; i<=N; i++) ps[i] = a[i] * f[i];
        ps[0] -= 1;
        c = 1;
        for (int i=0; i<=N; i++){
            s[i] = ps[i] * 2 + c * y;
            c *= x+1;
        }
        return s;
    };

    mint ans, M, c;
    ans = mint(H) * W;
    M = -mint((H-K+1) * (W-K+1)).inv();
    c = 1;
    vector<mint> hs=func(H), ws=func(W);
    for (int i=0; i<=N; i++){
        ans -= c * C(N, i) * hs[i] * ws[i];
        c *= M;
    }

    cout << ans.val() << endl;

    return 0;
}
0