#line 1 "/home/y_midori/cp/test/test.test.cpp" #define PROBLEM "https://yukicoder.me/problems/no/2459" #line 2 "template.hpp" // #pragma GCC target("avx2") // #pragma GCC optimize("O3") // #pragma GCC optimize("unroll-loops") #include using namespace std; template concept Streamable = requires(ostream os, T &x) { os << x; }; template concept is_modint = requires(mint &x) { { x.val() } -> std::convertible_to; }; #ifdef LOCAL #include #else #define debug(...) #endif template void print_one(const T &value) { cout << value; } template void print_one(const T &value) { cout << value.val(); } void print() { cout << '\n'; } template void print(const T &a, const Ts &...b) { print_one(a); ((cout << ' ', print_one(b)), ...); cout << '\n'; } template requires(!Streamable) void print(const Iterable &v) { for(auto it = v.begin(); it != v.end(); ++it) { if(it != v.begin()) cout << " "; print_one(*it); } cout << '\n'; } using ll = long long; using vl = vector; using vll = vector; using P = pair; #define all(v) v.begin(), v.end() #define UNIQUE(v) ranges::sort(v), v.erase(unique(all(v)), end(v)) template inline bool chmax(T &a, T b) { return ((a < b) ? (a = b, true) : (false)); } template inline bool chmin(T &a, T b) { return ((a > b) ? (a = b, true) : (false)); } // https://trap.jp/post/1224/ template constexpr auto min(T... a) { return min(initializer_list>{a...}); } template constexpr auto max(T... a) { return max(initializer_list>{a...}); } template void input(T &...a) { (cin >> ... >> a); } template void input(vector &a) { for(T &x : a) cin >> x; } #define INT(...) \ int __VA_ARGS__; \ input(__VA_ARGS__) #define LL(...) \ long long __VA_ARGS__; \ input(__VA_ARGS__) #define STR(...) \ string __VA_ARGS__; \ input(__VA_ARGS__) #define REP1(a) for(ll i = 0; i < a; i++) #define REP2(i, a) for(ll i = 0; i < a; i++) #define REP3(i, a, b) for(ll i = a; i < b; i++) #define REP4(i, a, b, c) for(ll i = a; i < b; i += c) #define overload4(a, b, c, d, e, ...) e #define rep(...) overload4(__VA_ARGS__, REP4, REP3, REP2, REP1)(__VA_ARGS__) #define rep1(i, n) for(ll i = 1; i <= ((ll)n); ++i) ll inf = 3e18; vl dx = {1, -1, 0, 0}; vl dy = {0, 0, 1, -1}; #line 3 "math/factorial.hpp" // https://suisen-cp.github.io/cp-library-cpp/library/math/factorial.hpp template struct factorial { factorial() {}; void ensure(const int n) { int sz = size(fac); if(sz > n) { return; } int new_sz = max(2 * sz, n + 1); fac.resize(new_sz), fac_inv.resize(new_sz); for(int i = sz; i < new_sz; i++) { if(i == 0) { fac[i] = 1; continue; } fac[i] = fac[i - 1] * i; } fac_inv[new_sz - 1] = T(1) / fac[new_sz - 1]; for(int i = new_sz - 2; i >= sz; i--) { fac_inv[i] = fac_inv[i + 1] * (i + 1); } return; } T get(int i) { ensure(i); return fac[i]; } T operator[](int i) { return get(i); } T inv(int i) { ensure(i); return fac_inv[i]; } T binom(int n, int i) { if(n < 0 || i < 0 || n < i) { return T(0); } ensure(n); return fac[n] * fac_inv[i] * fac_inv[n - i]; } T perm(int n, int i) { if(n < 0 || i < 0 || n < i) { return T(0); } ensure(n); return fac[n] * fac_inv[n - i]; } private: vector fac, fac_inv; }; #line 3 "poly/formal-power-series.hpp" #include // 10^9+7みたいなときconvolutionどうする? template struct FormalPowerSeries : vector { using vector::vector; using FPS = FormalPowerSeries; FormalPowerSeries(const vector &v) : vector(v) {} FPS &operator+=(const FPS &f) { if(this->size() < f.size()) this->resize(f.size()); for(int i = 0; i < ssize(f); ++i) (*this)[i] += f[i]; return *this; } FPS &operator-=(const FPS &f) { if(this->size() < f.size()) this->resize(f.size()); for(int i = 0; i < ssize(f); ++i) (*this)[i] -= f[i]; return *this; } FPS &operator*=(const FPS &f) { return (*this) = atcoder::convolution(*this, f); } FPS &operator*=(const mint &x) { for(mint &vi : *this) vi *= x; return *this; } FPS operator+(const FPS &f) const { return FPS(*this) += f; } FPS operator-(const FPS &f) const { return FPS(*this) -= f; } FPS operator*(const FPS &f) const { return FPS(*this) *= f; } FPS operator*(const mint &x) const { return FPS(*this) *= x; } FPS operator-() const { FPS res = *this; for(mint &vi : res) { vi = -vi; } return res; } FPS operator>>(const int sz) const { if(sz >= ssize(*this)) return {}; FPS res(begin(*this) + sz, end(*this)); return res; } FPS operator<<(const int sz) const { FPS res(sz, 0); res.insert(end(res), begin(*this), end(*this)); return res; } FPS inv(int deg = -1) const { assert(!this->empty() and (*this)[0] != mint(0)); if(deg == -1) deg = this->size(); FPS res = {(*this)[0].inv()}; FPS f; f.reserve(this->size()); for(int d = 1; d < deg << 1; d <<= 1) { while(ssize(f) < min(ssize(*this), d)) f.emplace_back((*this)[f.size()]); res *= (FPS({2}) - f * res); while(ssize(res) > min(d, deg)) res.pop_back(); } return res; } // なければ空を返す // 定数項が1でないときget_sqrtを渡す。解が複数ありうることに注意 FPS sqrt( int deg = -1, function get_sqrt = [](mint) { return mint(1); }) const { if(this->empty()) return {}; if(deg == -1) deg = this->size(); if((*this)[0] == mint(0)) { for(int i = 1; i < ssize(*this); ++i) { if((*this)[i] == mint(0)) continue; if(i & 1) return {}; if(i / 2 >= deg) break; FPS res = (*this >> i).sqrt(deg - i / 2, get_sqrt); if(res.empty()) return {}; res = res << (i / 2); return res; } return FPS(deg, 0); } FPS res{get_sqrt((*this)[0])}; if(res[0] * res[0] != (*this)[0]) return {}; FPS f; f.reserve(this->size()); mint inv2 = mint(1) / mint(2); for(int d = 1; d < deg << 1; d <<= 1) { while(ssize(f) < min(ssize(*this), d)) f.emplace_back((*this)[f.size()]); res = (res + f * res.inv(d)) * inv2; while(ssize(res) > min(d, deg)) res.pop_back(); } return res; } FPS diff() const { FPS res(max(0, ssize(*this) - 1)); for(int i = 1; i < ssize(*this); ++i) res[i - 1] = (mint)i * (*this)[i]; return res; } FPS integral() const { FPS res(ssize(*this) + 1); for(int i = 0; i < ssize(*this); ++i) res[i + 1] = (*this)[i] / mint(i + 1); return res; } FPS log(int deg = -1) const { assert(!this->empty() and (*this)[0] == (mint)1); if(deg == -1) deg = this->size(); if(deg == 0) return {}; FPS t(begin(*this), begin(*this) + min(deg, ssize(*this))); FPS res = t.diff() * t.inv(deg - 1); res.resize(deg - 1); return res.integral(); } FPS exp(int deg = -1) { assert(!this->empty() and (*this)[0] == (mint)0); if(deg == -1) deg = this->size(); if(deg == 0) return {}; FPS res = {1}; FPS f; f.reserve(this->size()); for(int d = 1; d < deg << 1; d <<= 1) { while(ssize(f) < min(ssize(*this), d)) f.emplace_back((*this)[f.size()]); res *= (FPS({1}) + f - res.log(d)); while(ssize(res) > min(d, deg)) res.pop_back(); } return res; } }; #line 3 "poly/sum-of-powers.hpp" /** * @brief 列の冪乗和 * @see https://yukicoder.me/problems/no/1145/editorial */ /// 各i(0≦i≦k)についてsum[j]a_j^iを求め、長さk+1の列を返す /// O(n*log(n)^2 + k*log(k))時間 template vector sum_of_powers(const vector &a, int k) { if(a.empty()) return vector(k + 1, 0); queue> que; for(auto &ai : a) que.push({1, -ai}); while(que.size() > 1) { auto f = que.front(); que.pop(); auto g = que.front(); que.pop(); que.push(f * g); } auto &f = que.front(); f = f.log(k + 1); for(int i = 1; i <= k; ++i) f[i] = -f[i] * mint(i); f[0] = ssize(a); return f; } /// 各i(0≦i≦k)についてsum[j in (0,n]]j^iを求め、長さk+1の列を返す /// O(n*log(n)^2 + k*log(k))時間 ←修正 template vector sum_of_powers_iota(int n, int k) { using FPS = FormalPowerSeries; FPS res = (FPS({0, n}).exp(k + 2) >> (1)) * (FPS({0, 1}).exp(k + 2) >> (1)).inv(k + 1); res.resize(k + 1); mint fac = 1; for(int i = 0; i <= k; ++i) { res[i] *= fac; fac *= i + 1; } debug(n, k, res); return res; } #line 5 "/home/y_midori/cp/test/test.test.cpp" #include using mint = atcoder::modint998244353; factorial fac; void solve() { INT(h, w, n, k); mint ans = 0; int imax = min(k, h - k), jmax = min(k, w - k); mint linv = mint(mint(h - k + 1) * (w - k + 1)).inv(); // 四隅 ans += mint(4) * imax * jmax; // mint wi, wj; vector sh = sum_of_powers_iota(imax + 1, n), sw = sum_of_powers_iota(jmax + 1, n); sh[0]--, sw[0]--; debug(imax, n, sh); for(int m = 0; m <= n; ++m) { ans -= 4 * fac.binom(n, m) * (-linv).pow(m) * sh[m] * sw[m]; } mint hl = abs(h - 2 * k), wl = abs(w - 2 * k); mint hr = (h >= 2 * k ? k : h - k + 1), wr = (w >= 2 * k ? k : w - k + 1); // 中 mint p = 1 - linv * hr * wr; ans += hl * wl * (1 - p.pow(n)); // 上下左右 ans += 2 * (wl * imax + hl * jmax); for(int m = 0; m <= n; ++m) { ans -= 2 * fac.binom(n, m) * (-linv).pow(m) * (hl * (hr.pow(m)) * sw[m] + wl * (wr.pow(m)) * sh[m]); } print(ans); } int main() { solve(); }