結果

問題 No.2683 Two Sheets
ユーザー srjywrdnprkt
提出日時 2024-03-29 12:09:21
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 27 ms / 2,000 ms
コード長 6,803 bytes
コンパイル時間 3,985 ms
コンパイル使用メモリ 249,780 KB
実行使用メモリ 11,172 KB
最終ジャッジ日時 2024-03-29 12:09:28
合計ジャッジ時間 5,677 ms
ジャッジサーバーID
(参考情報)
judge12 / judge11
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 15 ms
11,172 KB
testcase_01 AC 15 ms
11,172 KB
testcase_02 AC 15 ms
11,172 KB
testcase_03 AC 14 ms
11,172 KB
testcase_04 AC 15 ms
11,172 KB
testcase_05 AC 15 ms
11,172 KB
testcase_06 AC 15 ms
11,172 KB
testcase_07 AC 27 ms
11,172 KB
testcase_08 AC 15 ms
11,172 KB
testcase_09 AC 15 ms
11,172 KB
testcase_10 AC 15 ms
11,172 KB
testcase_11 AC 15 ms
11,172 KB
testcase_12 AC 14 ms
11,172 KB
testcase_13 AC 14 ms
11,172 KB
testcase_14 AC 14 ms
11,172 KB
testcase_15 AC 14 ms
11,172 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#include <atcoder/modint>
#include <atcoder/convolution>

using namespace std;
using namespace atcoder;
using ll = long long;

using mint = modint998244353;

// Largest index stored in the factorial tables below (the tables hold MAX+1 entries).
const ll MAX = 1e6+10;
// Filled by init(): f[i] = i! and finv[i] = (i!)^{-1}, both modulo 998244353.
vector<mint> f, finv;

// Raises x to the power 998244351 (= modulus - 2) by square-and-multiply.
// For x != 0 this is the multiplicative inverse of x modulo the prime 998244353
// (Fermat's little theorem); for x == 0 the result is 0.
mint inv(mint x){
    mint result = 1;

    // Walk the exponent's bits from least to most significant; x holds the
    // current repeated square.
    for (ll exponent = 998244351; exponent > 0; exponent >>= 1){
        if (exponent & 1) result *= x;
        x *= x;
    }

    return result;
}

// Builds the global factorial tables for indices 0..MAX:
//   f[i]    = i!        (mod 998244353)
//   finv[i] = (i!)^{-1} (mod 998244353)
// Only one modular inverse is computed; the rest follow from
// finv[i-1] = finv[i] * i.
void init(){
    const int top = static_cast<int>(MAX);

    f.resize(top + 1);
    finv.resize(top + 1);

    f[0] = 1;
    for (int n = 1; n <= top; ++n) f[n] = f[n-1] * n;

    finv[top] = inv(f[top]);
    for (int n = top; n > 0; --n) finv[n-1] = finv[n] * n;
}

// Binomial coefficient "n choose k" modulo 998244353, read from the tables
// built by init(). Returns 0 when k is negative or k > n.
// Indexes f[n] directly, so n must not exceed MAX.
mint C(ll n, ll k){
    const bool k_in_range = (0 <= k) && (k <= n);
    if (!k_in_range) return 0;

    return f[n] * finv[k] * finv[n-k];
}

// Truncated formal power series stored as a coefficient vector: element i is
// the coefficient of z^i. In-place operators keep the size of *this, i.e. all
// arithmetic is done modulo z^n where n is the current size.
// T must provide inv() and work with atcoder's convolution/butterfly routines
// (presumably an atcoder static modint with an NTT-friendly modulus).
template<class T>
struct FormalPowerSeries : vector<T> {
    using vector<T>::vector;
    using vector<T>::operator=;
    using F = FormalPowerSeries;

    // Returns a copy with every coefficient negated.
    F operator-() const {
        F res(*this);
        for (auto &e : res) e = -e;
        return res;
    }
    // Multiplies every coefficient by the scalar g.
    F &operator*=(const T &g) {
        for (auto &e : *this) e *= g;
        return *this;
    }
    // Divides every coefficient by the scalar g (asserts g != 0).
    F &operator/=(const T &g) {
        assert(g != T(0));
        *this *= g.inv();
        return *this;
    }
    // Adds g coefficient-wise over the first min(size, g.size()) terms.
    // The size of *this is unchanged, so higher terms of g are ignored.
    F &operator+=(const F &g) {
        int n = (*this).size(), m = g.size();
        for (int i=0; i<min(n, m); i++) (*this)[i] += g[i];
        return *this;
    }
    // Subtracts g coefficient-wise, with the same truncation as operator+=.
    F &operator-=(const F &g) {
        int n = (*this).size(), m = g.size();
        for (int i=0; i<min(n, m); i++) (*this)[i] -= g[i];
        return *this;
    }
    // Multiplies by z^d: prepends d zero coefficients, then cuts back to the
    // original length.
    F &operator<<=(const int d) {
        int n = (*this).size();
        (*this).insert((*this).begin(), d, 0);
        (*this).resize(n);
        return *this;
    }
    // Divides by z^d, discarding the lowest d coefficients, then pads back to
    // the original length with value-initialized coefficients.
    F &operator>>=(const int d) {
        int n = (*this).size();
        (*this).erase((*this).begin(), (*this).begin() + min(n, d));
        (*this).resize(n);
        return *this;
    }
    // Returns the first d coefficients of the multiplicative inverse series
    // (d == -1 means "as many as this series has").
    // Requires a non-empty series with a non-zero constant term.
    F inv(int d = -1) const {
        int n = (*this).size();
        assert(n != 0 && (*this)[0] != 0);
        if (d == -1) d = n;
        assert(d > 0);
        // Start from the inverse of the constant term; every pass of the loop
        // appends m new coefficients, doubling the length of res.
        F res{(*this)[0].inv()};
        while (res.size() < d) {
            int m = size(res);
            // f: the first 2m coefficients of *this; r: the current partial inverse.
            F f(begin(*this), begin(*this) + min(n, 2*m));
            F r(res);
            f.resize(2*m), internal::butterfly(f);
            r.resize(2*m), internal::butterfly(r);
            for (int i=0; i<2*m; i++) f[i] *= r[i];
            internal::butterfly_inv(f);
            // Keep only the upper half of the length-2m cyclic product, then
            // multiply by r once more in the transformed domain.
            f.erase(f.begin(), f.begin() + m);
            f.resize(2*m), internal::butterfly(f);
            for (int i=0; i<2*m; i++) f[i] *= r[i];
            internal::butterfly_inv(f);
            // iz = -(1/(2m))^2: one 1/(2m) factor per butterfly_inv call above
            // (presumably butterfly_inv leaves results unscaled -- confirm
            // against ACL), with the sign of the correction term folded in.
            T iz = T(2*m).inv(); iz *= -iz;
            for (int i=0; i<m; i++) f[i] *= iz;
            res.insert(res.end(), f.begin(), f.begin() + m);
        }
        return {res.begin(), res.begin() + d};
    }

    // Dense operations via atcoder::convolution (needs an NTT-friendly modulus).
    // Product with g, truncated to the original size of *this.
    F &operator*=(const F &g) {
        int n = (*this).size();
        *this = convolution(*this, g);
        (*this).resize(n);
        return *this;
    }
    // Quotient by g, truncated to the original size of *this; g must satisfy
    // the preconditions of inv().
    F &operator/=(const F &g) {
        int n = (*this).size();
        *this = convolution(*this, g.inv(n));
        (*this).resize(n);
        return *this;
    }

    // Sparse operations: g is a list of (degree, coefficient) terms.
    // g must be non-empty (g.front() is read); the early `break` assumes the
    // terms are sorted by increasing degree.
    // In-place product with the sparse series g, truncated to the current size.
    F &operator*=(vector<pair<int, T>> g) {
        int n = (*this).size();
        auto [d, c] = g.front();
        // c is the constant term of g (0 if g has no degree-0 entry).
        if (d == 0) g.erase(g.begin());
        else c = 0;
        // Go from high to low degree so that (*this)[i-j] still holds the
        // original, not yet overwritten, coefficient.
        for (int i=n-1; i>=0; i--){
            (*this)[i] *= c;
            for (auto &[j, b] : g) {
                if (j > i) break;
                (*this)[i] += (*this)[i-j] * b;
            }
        }
        return *this;
    }
    // In-place quotient by the sparse series g; its first term must be the
    // constant term and must be non-zero.
    F &operator/=(vector<pair<int, T>> g) {
        int n = (*this).size();
        auto [d, c] = g.front();
        assert(d == 0 && c != T(0));
        T ic = c.inv();
        g.erase(g.begin());
        // Go from low to high degree: (*this)[i-j] already holds the final
        // quotient coefficient when it is used.
        for (int i=0; i<n; i++){
            for (auto &[j, b] : g) {
            if (j > i) break;
            (*this)[i] -= (*this)[i-j] * b;
            }
            (*this)[i] *= ic;
        }
        return *this;
    }

    // In-place multiplication by the binomial (1 + c z^d), truncated to the
    // current size. c == 1 and c == -1 are special-cased to skip the multiply.
    void multiply(const int d, const T c) {
        int n = (*this).size();
        if (c == T(1)) for (int i=n-d-1; i>=0; i--) (*this)[i+d] += (*this)[i];
        else if (c == T(-1)) for (int i=n-d-1; i>=0; i--) (*this)[i+d] -= (*this)[i];
        else for (int i=n-d-1; i>=0; i--) (*this)[i+d] += (*this)[i] * c;
    }
    // In-place division by the binomial (1 + c z^d), truncated to the current
    // size; the loops run upward so each step uses already-final coefficients.
    void divide(const int d, const T c) {
        int n = (*this).size();
        if (c == T(1)) for (int i=0; i<n-d; i++) (*this)[i+d] -= (*this)[i];
        else if (c == T(-1)) for (int i=0; i<n-d; i++) (*this)[i+d] += (*this)[i];
        else for (int i=0; i<n-d; i++) (*this)[i+d] -= (*this)[i] * c;
    }

    // Evaluates the stored polynomial at the point a (sum of e_i * a^i).
    T eval(const T &a) const {
        T x(1), res(0);
        for (auto e : *this) res += e * x, x *= a;
        return res;
    }

    // Non-mutating counterparts: copy *this and apply the in-place operator.
    F operator*(const T &g) const { return F(*this) *= g; }
    F operator/(const T &g) const { return F(*this) /= g; }
    F operator+(const F &g) const { return F(*this) += g; }
    F operator-(const F &g) const { return F(*this) -= g; }
    F operator<<(const int d) const { return F(*this) <<= d; }
    F operator>>(const int d) const { return F(*this) >>= d; }
    F operator*(const F &g) const { return F(*this) *= g; }
    F operator/(const F &g) const { return F(*this) /= g; }
    F operator*(vector<pair<int, T>> g) const { return F(*this) *= g; }
    F operator/(vector<pair<int, T>> g) const { return F(*this) /= g; }
};

// Same alias as the one declared near the top of the file; repeating an
// identical alias declaration is permitted.
using mint = modint998244353;
// Power series with coefficients modulo 998244353.
using fps = FormalPowerSeries<mint>;

int main(){

    init();

    ll H, W, N=2, hx, wx, hy, wy, A, B;
    mint c, ans, M;
    cin >> H >> W >> A >> B;
    vector<mint> hps(N+1), wps(N+1), hs(N+1), ws(N+1);
    fps a(N+1), b(N+1);

    if (A*2 <= H){
        hx = A-1; hy = H-A*2+2;
    } 
    else{
        hx = H-A; hy = H-(H-A)*2;
    }
    if (B*2 <= W){
        wx = B-1; wy = W-B*2+2;
    }   
    else{
        wx = W-B; wy = W-(W-B)*2;
    } 

    c = 1;
    for (int i=0; i<=N; i++){
        b[i] = finv[i+1];
        c *= hx+1;
        a[i] = c * finv[i+1];
    }
    a *= b.inv();
    for (int i=0; i<=N; i++) hps[i] = a[i] * f[i];
    hps[0] -= 1;

    c = 1;
    for (int i=0; i<=N; i++){
        b[i] = finv[i+1];
        c *= wx+1;
        a[i] = c * finv[i+1];
    }
    a *= b.inv();
    for (int i=0; i<=N; i++) wps[i] = a[i] * f[i];
    wps[0] -= 1;

    ans = (hx * 2 + hy) * (wx * 2 + wy); 

    c = 1;
    for (int i=0; i<=N; i++){
        hs[i] = hps[i] * 2 + c * hy;
        c *= hx+1;
    }

    c = 1;
    for (int i=0; i<=N; i++){
        ws[i] = wps[i] * 2 + c * wy;
        c *= wx+1;
    }

    M = -mint((H-A+1) * (W-B+1)).inv();
    c = 1;
    for (int i=0; i<=N; i++){
        ans -= c * C(N, i) * hs[i] * ws[i];
        c *= M;
    }

    cout << ans.val() << endl;

    return 0;
}
0