結果
問題 | No.2303 Frog on Grid |
ユーザー | ei1333333 |
提出日時 | 2023-05-12 21:52:07 |
言語 | C++17(gcc12) (gcc 12.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 125 ms / 2,000 ms |
コード長 | 12,914 bytes |
コンパイル時間 | 2,806 ms |
コンパイル使用メモリ | 216,720 KB |
実行使用メモリ | 12,344 KB |
最終ジャッジ日時 | 2024-11-28 17:58:02 |
合計ジャッジ時間 | 4,912 ms |
ジャッジサーバーID (参考情報) | judge1 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 20 |
ソースコード
#include <bits/stdc++.h>

using namespace std;

using int64 = long long;
// const int mod = 1e9 + 7;
const int mod = 998244353;

const int64 infll = (1LL << 62) - 1;
const int inf = (1 << 30) - 1;

/*
struct IoSetup {
  IoSetup() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    cout << fixed << setprecision(10);
    cerr << fixed << setprecision(10);
  }
} iosetup;
*/

/// Writes a pair as "first second".
template< typename T1, typename T2 >
ostream &operator<<(ostream &os, const pair< T1, T2 > &p) {
  os << p.first << " " << p.second;
  return os;
}

/// Reads a pair as two whitespace-separated values.
template< typename T1, typename T2 >
istream &operator>>(istream &is, pair< T1, T2 > &p) {
  is >> p.first >> p.second;
  return is;
}

/// Writes the elements of a vector separated by single spaces (no trailing space).
template< typename T >
ostream &operator<<(ostream &os, const vector< T > &v) {
  for (size_t i = 0; i < v.size(); i++) {
    os << v[i] << (i + 1 != v.size() ? " " : "");
  }
  return os;
}

/// Reads v.size() values into an already-sized vector.
template< typename T >
istream &operator>>(istream &is, vector< T > &v) {
  for (T &in: v) is >> in;
  return is;
}

/// Sets a = b when a < b; returns whether a was updated.
template< typename T1, typename T2 >
inline bool chmax(T1 &a, T2 b) {
  return a < b && (a = b, true);
}

/// Sets a = b when a > b; returns whether a was updated.
template< typename T1, typename T2 >
inline bool chmin(T1 &a, T2 b) {
  return a > b && (a = b, true);
}

/// Builds a value-initialised one-dimensional vector of length a.
template< typename T = int64 >
vector< T > make_v(size_t a) {
  return vector< T >(a);
}

/// Builds a nested vector whose dimensions are a, ts... (outermost first).
template< typename T, typename... Ts >
auto make_v(size_t a, Ts... ts) {
  return vector< decltype(make_v< T >(ts...)) >(a, make_v< T >(ts...));
}

/// Base case of fill_v: a non-class element is simply assigned.
template< typename T, typename V >
typename enable_if< is_class< T >::value == 0 >::type fill_v(T &t, const V &v) {
  t = v;
}

/// Recursive case of fill_v: descends into every element of a class-type range.
template< typename T, typename V >
typename enable_if< is_class< T >::value != 0 >::type fill_v(T &t, const V &v) {
  for (auto &e: t) fill_v(e, v);
}

/// Wraps a callable so that it receives itself as its first argument,
/// which lets a lambda call itself recursively.
template< typename F >
struct FixPoint: F {
  FixPoint(F &&f): F(forward< F >(f)) {}

  template< typename... Args >
  decltype(auto) operator()(Args &&... args) const {
    return F::operator()(*this, forward< Args >(args)...);
  }
};

/// Convenience factory for FixPoint.
template< typename F >
inline decltype(auto) MFP(F &&f) {
  return FixPoint< F >{forward< F >(f)};
}

// ---- bundled from math/combinatorics/montgomery-mod-int.hpp ----

/**
 * @brief Montgomery ModInt
 *
 * Integers modulo the odd compile-time constant `mod` (< 2^30), stored in
 * Montgomery form with R = 2^32. The stored word `x` is kept in [0, 2 * mod),
 * so comparisons and get() normalise before use.
 *
 * When `fast` is true the constructor skips the `a % mod + mod` normalisation
 * and feeds `a` to the reduction as-is.
 */
template< uint32_t mod, bool fast = false >
struct MontgomeryModInt {
  using mint = MontgomeryModInt;
  using i32 = int32_t;
  using i64 = int64_t;
  using u32 = uint32_t;
  using u64 = uint64_t;

  /// Inverse of mod modulo 2^32, by Newton iteration (checked by the static_assert below).
  static constexpr u32 get_r() {
    u32 ret = mod;
    for (i32 i = 0; i < 4; i++) ret *= 2 - mod * ret;
    return ret;
  }

  static constexpr u32 r = get_r();
  /// 2^64 mod `mod`; multiplying by it and reducing converts into Montgomery form.
  static constexpr u32 n2 = -u64(mod) % mod;

  static_assert(r * mod == 1, "invalid, r * mod != 1");
  static_assert(mod < (1 << 30), "invalid, mod >= 2 ^ 30");
  static_assert((mod & 1) == 1, "invalid, mod % 2 == 0");

  u32 x;

  MontgomeryModInt(): x{} {}

  // Intentionally implicit: the rest of the file relies on int -> mint conversions.
  MontgomeryModInt(const i64 &a)
      : x(reduce(u64(fast ? a : (a % mod + mod)) * n2)) {}

  /// Montgomery reduction of a 64-bit product.
  static constexpr u32 reduce(const u64 &b) {
    return u32(b >> 32) + mod - u32((u64(u32(b) * r) * mod) >> 32);
  }

  mint &operator+=(const mint &p) {
    if (i32(x += p.x - 2 * mod) < 0) x += 2 * mod;
    return *this;
  }

  mint &operator-=(const mint &p) {
    if (i32(x -= p.x) < 0) x += 2 * mod;
    return *this;
  }

  mint &operator*=(const mint &p) {
    x = reduce(u64(x) * p.x);
    return *this;
  }

  mint &operator/=(const mint &p) {
    *this *= p.inverse();
    return *this;
  }

  mint operator-() const { return mint() - *this; }

  mint operator+(const mint &p) const { return mint(*this) += p; }

  mint operator-(const mint &p) const { return mint(*this) -= p; }

  mint operator*(const mint &p) const { return mint(*this) *= p; }

  mint operator/(const mint &p) const { return mint(*this) /= p; }

  bool operator==(const mint &p) const {
    return (x >= mod ? x - mod : x) == (p.x >= mod ? p.x - mod : p.x);
  }

  bool operator!=(const mint &p) const {
    return (x >= mod ? x - mod : x) != (p.x >= mod ? p.x - mod : p.x);
  }

  /// Returns the ordinary (non-Montgomery) representative in [0, mod).
  u32 get() const {
    u32 ret = reduce(x);
    return ret >= mod ? ret - mod : ret;
  }

  /// Binary exponentiation.
  mint pow(u64 n) const {
    mint ret(1), mul(*this);
    while (n > 0) {
      if (n & 1) ret *= mul;
      mul *= mul;
      n >>= 1;
    }
    return ret;
  }

  /// Computed as this^(mod - 2).
  mint inverse() const {
    return pow(mod - 2);
  }

  friend ostream &operator<<(ostream &os, const mint &p) {
    return os << p.get();
  }

  friend istream &operator>>(istream &is, mint &a) {
    i64 t;
    is >> t;
    a = mint(t);
    return is;
  }

  static u32 get_mod() { return mod; }
};

using modint = MontgomeryModInt< mod >;

// ---- bundled from math/combinatorics/enumeration.hpp ----

/**
 * @brief Enumeration (combinatorics)
 *
 * Shared, lazily grown tables of k!, 1/k! and 1/k over the field type T.
 * Every accessor grows the tables on demand, so no up-front sizing is
 * required; sizing once up front avoids one inversion per growth step.
 */
template< typename T >
struct Enumeration {
private:
  static vector< T > _fact, _finv, _inv;

  /// Ensures the tables cover indices 0..sz, extending them if necessary.
  inline static void expand(size_t sz) {
    if (_fact.size() < sz + 1) {
      int pre_sz = max(1, (int) _fact.size());
      _fact.resize(sz + 1, T(1));
      _finv.resize(sz + 1, T(1));
      _inv.resize(sz + 1, T(1));
      for (int i = pre_sz; i <= (int) sz; i++) {
        _fact[i] = _fact[i - 1] * T(i);
      }
      // One inversion at the top, then walk downwards: 1/i! = 1/(i+1)! * (i+1).
      _finv[sz] = T(1) / _fact[sz];
      for (int i = (int) sz - 1; i >= pre_sz; i--) {
        _finv[i] = _finv[i + 1] * T(i + 1);
      }
      for (int i = pre_sz; i <= (int) sz; i++) {
        _inv[i] = _finv[i] * _fact[i - 1];
      }
    }
  }

public:
  explicit Enumeration(size_t sz = 0) { expand(sz); }

  /// k!
  static inline T fact(int k) {
    expand(k);
    return _fact[k];
  }

  /// 1 / k!
  static inline T finv(int k) {
    expand(k);
    return _finv[k];
  }

  /// 1 / k
  static inline T inv(int k) {
    expand(k);
    return _inv[k];
  }

  /// n! / (n - r)!, or 0 when r < 0 or n < r.
  static T P(int n, int r) {
    if (r < 0 || n < r) return 0;
    return fact(n) * finv(n - r);
  }

  /// Binomial coefficient p choose q, or 0 when q < 0 or p < q.
  static T C(int p, int q) {
    if (q < 0 || p < q) return 0;
    return fact(p) * finv(q) * finv(p - q);
  }

  /// C(n + r - 1, r) (1 when r == 0), or 0 when n < 0 or r < 0.
  static T H(int n, int r) {
    if (n < 0 || r < 0) return 0;
    return r == 0 ? 1 : C(n + r - 1, r);
  }
};

template< typename T >
vector< T > Enumeration< T >::_fact = vector< T >();
template< typename T >
vector< T > Enumeration< T >::_finv = vector< T >();
template< typename T >
vector< T > Enumeration< T >::_inv = vector< T >();

// ---- bundled from math/fft/number-theoretic-transform-friendly-mod-int.hpp ----

/**
 * @brief Number Theoretic Transform Friendly ModInt
 *
 * In-place transforms and convolution over Mint. Transform lengths must be a
 * power of two not exceeding 2^max_base, where max_base is the exponent of 2
 * in (mod - 1); both conditions are asserted.
 */
template< typename Mint >
struct NumberTheoreticTransformFriendlyModInt {

  static vector< Mint > roots, iroots, rate3, irate3;
  static int max_base;

  NumberTheoreticTransformFriendlyModInt() = default;

  /// Builds the root tables on first use; later calls are no-ops.
  static void init() {
    if (roots.empty()) {
      const unsigned mod = Mint::get_mod();
      assert(mod >= 3 && mod % 2 == 1);
      auto tmp = mod - 1;
      max_base = 0;
      while (tmp % 2 == 0) tmp >>= 1, max_base++;
      // Smallest value >= 2 whose (mod-1)/2-th power is not 1.
      Mint root = 2;
      while (root.pow((mod - 1) >> 1) == 1) {
        root += 1;
      }
      assert(root.pow(mod - 1) == 1);

      roots.resize(max_base + 1);
      iroots.resize(max_base + 1);
      rate3.resize(max_base + 1);
      irate3.resize(max_base + 1);

      // roots[i] is obtained by repeated squaring from roots[max_base].
      roots[max_base] = root.pow((mod - 1) >> max_base);
      iroots[max_base] = Mint(1) / roots[max_base];
      for (int i = max_base - 1; i >= 0; i--) {
        roots[i] = roots[i + 1] * roots[i + 1];
        iroots[i] = iroots[i + 1] * iroots[i + 1];
      }
      {
        Mint prod = 1, iprod = 1;
        for (int i = 0; i <= max_base - 3; i++) {
          rate3[i] = roots[i + 3] * prod;
          irate3[i] = iroots[i + 3] * iprod;
          prod *= iroots[i + 3];
          iprod *= roots[i + 3];
        }
      }
    }
  }

  /// Forward transform of `a`, in place. An empty vector is left untouched.
  static void ntt(vector< Mint > &a) {
    if (a.empty()) return;  // __builtin_ctz(0) is undefined
    init();
    const int n = (int) a.size();
    assert((n & (n - 1)) == 0);
    int h = __builtin_ctz(n);
    assert(h <= max_base);
    int len = 0;
    Mint imag = roots[2];
    // Odd log-length: do one radix-2 pass so the remaining passes are radix-4.
    if (h & 1) {
      int p = 1 << (h - 1);
      for (int i = 0; i < p; i++) {
        auto r = a[i + p];
        a[i + p] = a[i] - r;
        a[i] += r;
      }
      len++;
    }
    for (; len + 1 < h; len += 2) {
      int p = 1 << (h - len - 2);
      { // s = 0 (all twiddles are 1, so the multiplications are skipped)
        for (int i = 0; i < p; i++) {
          auto a0 = a[i];
          auto a1 = a[i + p];
          auto a2 = a[i + 2 * p];
          auto a3 = a[i + 3 * p];
          auto a1na3imag = (a1 - a3) * imag;
          auto a0a2 = a0 + a2;
          auto a1a3 = a1 + a3;
          auto a0na2 = a0 - a2;
          a[i] = a0a2 + a1a3;
          a[i + 1 * p] = a0a2 - a1a3;
          a[i + 2 * p] = a0na2 + a1na3imag;
          a[i + 3 * p] = a0na2 - a1na3imag;
        }
      }
      Mint rot = rate3[0];
      for (int s = 1; s < (1 << len); s++) {
        int offset = s << (h - len);
        Mint rot2 = rot * rot;
        Mint rot3 = rot2 * rot;
        for (int i = 0; i < p; i++) {
          auto a0 = a[i + offset];
          auto a1 = a[i + offset + p] * rot;
          auto a2 = a[i + offset + 2 * p] * rot2;
          auto a3 = a[i + offset + 3 * p] * rot3;
          auto a1na3imag = (a1 - a3) * imag;
          auto a0a2 = a0 + a2;
          auto a1a3 = a1 + a3;
          auto a0na2 = a0 - a2;
          a[i + offset] = a0a2 + a1a3;
          a[i + offset + 1 * p] = a0a2 - a1a3;
          a[i + offset + 2 * p] = a0na2 + a1na3imag;
          a[i + offset + 3 * p] = a0na2 - a1na3imag;
        }
        rot *= rate3[__builtin_ctz(~s)];
      }
    }
  }

  /**
   * Inverse transform of `a`, in place. An empty vector is left untouched.
   * @param f when true, every element is finally multiplied by 1/n; pass
   *          false if the caller has already applied that factor.
   */
  static void intt(vector< Mint > &a, bool f = true) {
    if (a.empty()) return;  // __builtin_ctz(0) is undefined
    init();
    const int n = (int) a.size();
    assert((n & (n - 1)) == 0);
    int h = __builtin_ctz(n);
    assert(h <= max_base);
    int len = h;
    Mint iimag = iroots[2];
    for (; len > 1; len -= 2) {
      int p = 1 << (h - len);
      { // s = 0 (all twiddles are 1, so the multiplications are skipped)
        for (int i = 0; i < p; i++) {
          auto a0 = a[i];
          auto a1 = a[i + 1 * p];
          auto a2 = a[i + 2 * p];
          auto a3 = a[i + 3 * p];
          auto a2na3iimag = (a2 - a3) * iimag;
          auto a0na1 = a0 - a1;
          auto a0a1 = a0 + a1;
          auto a2a3 = a2 + a3;
          a[i] = a0a1 + a2a3;
          a[i + 1 * p] = (a0na1 + a2na3iimag);
          a[i + 2 * p] = (a0a1 - a2a3);
          a[i + 3 * p] = (a0na1 - a2na3iimag);
        }
      }
      Mint irot = irate3[0];
      for (int s = 1; s < (1 << (len - 2)); s++) {
        int offset = s << (h - len + 2);
        Mint irot2 = irot * irot;
        Mint irot3 = irot2 * irot;
        for (int i = 0; i < p; i++) {
          auto a0 = a[i + offset];
          auto a1 = a[i + offset + 1 * p];
          auto a2 = a[i + offset + 2 * p];
          auto a3 = a[i + offset + 3 * p];
          auto a2na3iimag = (a2 - a3) * iimag;
          auto a0na1 = a0 - a1;
          auto a0a1 = a0 + a1;
          auto a2a3 = a2 + a3;
          a[i + offset] = a0a1 + a2a3;
          a[i + offset + 1 * p] = (a0na1 + a2na3iimag) * irot;
          a[i + offset + 2 * p] = (a0a1 - a2a3) * irot2;
          a[i + offset + 3 * p] = (a0na1 - a2na3iimag) * irot3;
        }
        irot *= irate3[__builtin_ctz(~s)];
      }
    }
    // Odd log-length: one radix-2 pass remains.
    if (len >= 1) {
      int p = 1 << (h - 1);
      for (int i = 0; i < p; i++) {
        auto ajp = a[i] - a[i + p];
        a[i] += a[i + p];
        a[i + p] = ajp;
      }
    }
    if (f) {
      Mint inv_sz = Mint(1) / n;
      for (int i = 0; i < n; i++) a[i] *= inv_sz;
    }
  }

  /**
   * Polynomial product of a and b (coefficient vectors, lowest degree first).
   * @return a.size() + b.size() - 1 coefficients, or an empty vector when
   *         either operand is empty.
   */
  static vector< Mint > multiply(vector< Mint > a, vector< Mint > b) {
    // Without this guard `need` would be -1 (or smaller than an operand) and
    // the final resize would request an enormous allocation.
    if (a.empty() || b.empty()) return {};
    int need = static_cast< int >(a.size() + b.size() - 1);
    int nbase = 1;
    while ((1 << nbase) < need) nbase++;
    int sz = 1 << nbase;
    a.resize(sz, 0);
    b.resize(sz, 0);
    ntt(a);
    ntt(b);
    // Fold the 1/sz scaling into the pointwise product, so intt can skip it.
    Mint inv_sz = Mint(1) / sz;
    for (int i = 0; i < sz; i++) a[i] *= b[i] * inv_sz;
    intt(a, false);
    a.resize(need);
    return a;
  }
};

template< typename Mint >
vector< Mint > NumberTheoreticTransformFriendlyModInt< Mint >::roots = vector< Mint >();
template< typename Mint >
vector< Mint > NumberTheoreticTransformFriendlyModInt< Mint >::iroots = vector< Mint >();
template< typename Mint >
vector< Mint > NumberTheoreticTransformFriendlyModInt< Mint >::rate3 = vector< Mint >();
template< typename Mint >
vector< Mint > NumberTheoreticTransformFriendlyModInt< Mint >::irate3 = vector< Mint >();
template< typename Mint >
int NumberTheoreticTransformFriendlyModInt< Mint >::max_base = 0;

/**
 * For every step count i in [0, n], returns C(i, 2i - n) / i!.
 *
 * C(i, 2i - n) is the number of ordered ways to write n as a sum of exactly
 * i terms, each 1 or 2: with x ones and y twos, x + y = i and x + 2y = n
 * give x = 2i - n and y = n - i. Entries with x < 0 are 0.
 * Returns an empty vector when n < 0.
 */
static vector< modint > step_sequences_over_factorial(int n) {
  if (n < 0) return {};
  vector< modint > ret(n + 1);
  for (int i = 0; i <= n; i++) {
    const int x = 2 * i - n;
    const int y = n - i;
    if (0 <= x and 0 <= y) {
      ret[i] = Enumeration< modint >::C(x + y, x) * Enumeration< modint >::finv(i);
    }
  }
  return ret;
}

/**
 * Reads H and W and prints
 *   sum over (i, j) of  ways(H, i) * ways(W, j) * (i + j)! / (i! * j!)   (mod 998244353),
 * where ways(n, k) is the number of ordered sums of k terms from {1, 2} equal
 * to n. The (i + j)! / (i! j!) factor counts the interleavings of the two
 * step sequences, so the double sum is one convolution followed by a
 * multiplication by (i + j)!.
 * NOTE(review): presumably each move advances 1 or 2 cells along one axis —
 * inferred from the equations, confirm against the problem statement.
 */
int main() {
  int H, W;
  if (!(cin >> H >> W)) return 0;

  // Size the factorial tables once instead of growing them one index at a time.
  (void) Enumeration< modint >::fact(max(H, 0) + max(W, 0));

  const vector< modint > f = step_sequences_over_factorial(H);
  const vector< modint > g = step_sequences_over_factorial(W);
  const auto v = NumberTheoreticTransformFriendlyModInt< modint >::multiply(f, g);

  modint ret = 0;
  for (size_t i = 0; i < v.size(); i++) {
    ret += Enumeration< modint >::fact(static_cast< int >(i)) * v[i];
  }
  cout << ret << "\n";
}