結果

問題 No.2683 Two Sheets
ユーザー downerdowner
提出日時 2024-03-20 22:19:58
言語 C++23
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 265 ms / 2,000 ms
コード長 5,520 bytes
コンパイル時間 2,820 ms
コンパイル使用メモリ 255,344 KB
実行使用メモリ 22,864 KB
最終ジャッジ日時 2024-09-30 08:16:49
合計ジャッジ時間 5,431 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 2 ms
5,248 KB
testcase_02 AC 12 ms
5,248 KB
testcase_03 AC 201 ms
17,760 KB
testcase_04 AC 30 ms
5,760 KB
testcase_05 AC 112 ms
17,536 KB
testcase_06 AC 121 ms
15,360 KB
testcase_07 AC 70 ms
12,928 KB
testcase_08 AC 265 ms
22,272 KB
testcase_09 AC 99 ms
13,096 KB
testcase_10 AC 184 ms
22,272 KB
testcase_11 AC 152 ms
17,920 KB
testcase_12 AC 186 ms
17,380 KB
testcase_13 AC 87 ms
22,864 KB
testcase_14 AC 241 ms
22,616 KB
testcase_15 AC 134 ms
15,616 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;

// modint
// using mint = MontgomeryModInt64; と定義する
// mint::set_mod(M); でmodをMに設定する
struct MontgomeryModInt64 {
    using mint = MontgomeryModInt64;
    using u64 = uint64_t;
    using u128 = __uint128_t;

    // static変数
    // R = 2 ^ 64
    static inline u64 MOD;
    static inline u64 INV_MOD;  // INV_MOD * MOD ≡ 1 (mod 2 ^ 64)
    static inline u64 T128;     // 2 ^ 128 (mod MOD)

    u64 val;

    // コンストラクタ
    // MODを足す?
    MontgomeryModInt64(): val(0) {}
    MontgomeryModInt64(long long v): val(MR((u128(v) + MOD) * T128)) {}

    // 値を返す
    u64 get() const {
        u64 res = MR(val);
        return res >= MOD ? res - MOD : res;
    }

    // static関数
    static u64 get_mod() { return MOD; }
    static void set_mod(u64 mod) {
        MOD = mod;
        T128 = -u128(mod) % mod;
        INV_MOD = get_inv_mod();
    }
    // ニュートン法で逆元を求める
    static u64 get_inv_mod() {
        u64 res = MOD;
        for(int i = 0; i < 5; ++i) res *= 2 - MOD * res;
        return res;
    }
    // モンゴメリリダクション
    static u64 MR(const u128& v) {
        return (v + u128(u64(v) * u64(-INV_MOD)) * MOD) >> 64;
    }

    // 算術演算子
    mint operator - () const { return mint() - mint(*this); }

    mint operator + (const mint& r) const { return mint(*this) += r; }
    mint operator - (const mint& r) const { return mint(*this) -= r; }
    mint operator * (const mint& r) const { return mint(*this) *= r; }
    mint operator / (const mint& r) const { return mint(*this) /= r; }

    mint& operator += (const mint& r) {
        if((val += r.val) >= 2 * MOD) val -= 2 * MOD;
        return *this;
    }
    mint& operator -= (const mint& r) {
        if((val += 2 * MOD - r.val) >= 2 * MOD) val -= 2 * MOD;
        return *this;
    }
    mint& operator *= (const mint& r) {
        val = MR(u128(val) * r.val);
        return *this;
    }
    mint& operator /= (const mint& r) {
        *this *= r.inv();
        return *this;
    }

    mint inv() const { return pow(MOD - 2); }
    mint pow(u128 n) const {
        mint res(1), mul(*this);
        while(n > 0) {
            if(n & 1) res *= mul;
            mul *= mul;
            n >>= 1;
        }
        return res;
    }

    // その他演算子
    bool operator == (const mint& r) const {
        return (val >= MOD ? val - MOD : val) == (r.val >= MOD ? r.val - MOD : r.val);
    }
    bool operator != (const mint& r) const {
        return (val >= MOD ? val - MOD : val) != (r.val >= MOD ? r.val - MOD : r.val);
    }

    // 入力
    friend istream& operator >> (istream& is, mint& x) {
        long long t;
        is >> t;
        x = mint(t);
        return is;
    }
    // 出力
    friend ostream& operator << (ostream& os, const mint& x) {
        return os << x.get();
    }
    friend mint modpow(const mint& r, long long n) {
        return r.pow(n);
    } 
    friend mint modinv(const mint& r) {
        return r.inv();
    }
};

using mint = MontgomeryModInt64;

using S = mint;
S op(S l, S r) { return l + r; }
S e = 0;

using F = mint;
S mapping(F f, S x) {return f + x;}
F composition(F f, F g) {return f + g;}
F id = 0;


struct LazySegmentTree{
    int n;
    vector<S> data;
    vector<F> lazy;

    LazySegmentTree(vector<S> &v) {
        int sz = v.size();
        n = 1; while(n < sz) n *= 2;
        data.resize(2 * n, e);
        lazy.resize(2 * n, id);

        for(int i = 0; i < sz; i++) {
            data[i+n] = v[i];
        }
        for(int i = n - 1; i > 0; i--) {
            data[i] = op(data[i<<1], data[i<<1|1]);
        }
    }

    void eval(int i) {
        if(lazy[i] == id) return;
        data[i] = mapping(lazy[i], data[i]);

        if(i < n) {
            lazy[i<<1] = composition(lazy[i], lazy[i<<1]);
            lazy[i<<1|1] = composition(lazy[i], lazy[i<<1|1]);
        }

        lazy[i] = id;
    }

    void apply(int a, int b, F x, int i=1, int l=0, int r=-1) {
        if(r < 0) r = n;
        eval(i);
        if(r <= a || b <= l) return;
        if(a <= l && r <= b) {
            lazy[i] = x;
            eval(i);
        }
        else {
            int mid = (l + r) / 2;
            apply(a, b, x, i << 1, l, mid);
            apply(a, b, x, i << 1 | 1, mid, r);
            data[i] = op(data[i<<1], data[i<<1|1]);
        }
    }

    S prod(int a, int b, int i=1, int l=0, int r=-1) {
        if(r < 0) r = n;
        if(r <= a || b <= l) return e;

        eval(i);
        if(a <= l && r <= b) return data[i];
        int mid = (l + r) / 2;
        return op(prod(a, b, i << 1, l, mid), prod(a, b, i << 1 | 1, mid, r));
    }
};

int main() {
    mint::set_mod(998244353);

    int H, W, A, B;
    cin >> H >> W >> A >> B;

    mint ans = mint(A) * B * 2;

    // 重なる区間の期待値はX1 * Y1
    // X1は重なる区間(縦),Y1は重なる区間(横)の期待値
    
    vector<mint> vx(H, 0), vy(W, 0);
    LazySegmentTree X(vx), Y(vy);

    for(int i = 0; i < H - A + 1; i++) {
        X.apply(i, i + A, mint(H - A + 1).inv());
    }
    for(int i = 0; i < W - B + 1; i++) {
        Y.apply(i, i + B, mint(W - B + 1).inv());
    }

    mint overx = 0;
    for(int i = 0; i < H; i++) {
        overx += X.prod(i, i + 1).pow(2);
    }
    mint overy = 0;
    for(int i = 0; i < W; i++) {
        overy += Y.prod(i, i + 1).pow(2);
    }

    ans -= overx * overy;
    cout << ans << endl;

    return 0;
}
0