結果
問題 | No.3095 Many Min Problems |
ユーザー |
![]() |
提出日時 | 2025-04-06 15:43:56 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 157 ms / 2,000 ms |
コード長 | 18,859 bytes |
コンパイル時間 | 3,261 ms |
コンパイル使用メモリ | 223,004 KB |
実行使用メモリ | 14,788 KB |
最終ジャッジ日時 | 2025-04-06 15:44:03 |
合計ジャッジ時間 | 6,508 ms |
ジャッジサーバーID (参考情報) | judge2 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 30 |
ソースコード
#include <bits/stdc++.h> using namespace std; #include <atcoder/modint> using mint = atcoder::modint998244353; template<typename T> T gcd(T a, T b) { return b == 0 ? a : gcd(a % b, a); } namespace combination { template<typename mint> struct C { static vector<mint> fac, finv; static void init(int n) { int sz = fac.size(); if (n < sz) return; n = clamp(n, 2 * sz, min(1 << 25, mint::mod() - 1)); fac.resize(n + 1); finv.resize(n + 1); for (int i = sz; i <= n; i++) { fac[i] = i * fac[i - 1]; } finv[n] = fac[n].inv(); for (int i = n; i >= sz; i--) { finv[i - 1] = i * finv[i]; } } }; template<typename mint> vector<mint> C<mint>::fac(1, 1); template<typename mint> vector<mint> C<mint>::finv(1, 1); template<typename mint> mint fac(int n) { C<mint>::init(n); if (n < 0) return 0; return C<mint>::fac[n]; } template<typename mint> mint finv(int n) { C<mint>::init(n); if (n < 0) return 0; return C<mint>::finv[n]; } template<typename mint> mint mod_inv(int n) { assert(n > 0); return finv<mint>(n) * fac<mint>(n - 1); } template<typename mint> mint nCk(int n, int k) { if (n < 0 || n < k || k < 0) return 0; return fac<mint>(n) * finv<mint>(n - k) * finv<mint>(k); } template<typename mint> mint multi_C(const vector<int> &v) { int n = 0; for (const int &k : v) n += k; mint res = fac<mint>(n); for (const int &k : v) res *= finv<mint>(k); return res; } template<typename mint> mint nPk(int n, int k) { if (n < 0 || n < k || k < 0) return 0; return fac<mint>(n) * finv<mint>(n - k); } template<typename mint> mint catalan(int n) { return fac<mint>(2 * n) * finv<mint>(n) * finv<mint>(n + 1); } template<typename mint> mint grid_path(int n, int m) { return nCk<mint>(n + m, n); } } // namespace combination struct montgomery_modint { using int64 = uint64_t; using int128 = __uint128_t; using modint = montgomery_modint; montgomery_modint() : x(0) {} montgomery_modint(long long v) : x(reduce((int128(v) + MOD) * R)) {} static void set_mod(long long _m) { MOD = _m; R = -int128(MOD) % MOD; INV = 
get_inv_mod(); } static long long mod() { return MOD; } long long val() const { int64 res = reduce(x); return res >= MOD ? res - MOD : res; } modint& operator+=(const modint &r) { x += r.x; if (x >= (MOD << 1)) x -= (MOD << 1); return *this; } modint& operator-=(const modint &r) { x += (MOD << 1) - r.x; if (x >= (MOD << 1)) x -= (MOD << 1); return *this; } modint& operator*=(const modint &r) { x = reduce(int128(x) * r.x); return *this; } modint& operator/=(const modint &r) { *this *= r.inv(); return *this; } friend modint operator+(const modint &a, const modint &b) { return modint(a) += b; } friend modint operator-(const modint &a, const modint &b) { return modint(a) -= b; } friend modint operator*(const modint &a, const modint &b) { return modint(a) *= b; } friend modint operator/(const modint &a, const modint &b) { return modint(a) /= b; } friend bool operator==(const modint &a, const modint &b) { return a.val() == b.val(); } friend bool operator!=(const modint &a, const modint &b) { return a.val() != b.val(); } modint operator+() const { return *this; } modint operator-() const { return modint() - *this; } modint inv() const { return pow(MOD - 2); } modint pow(int128 k) const { modint a = *this; modint res = 1; while (k > 0) { if (k & 1) res *= a; a *= a; k >>= 1; } return res; } private: int64 x; static int64 MOD, INV, R; static int64 get_inv_mod() { int64 res = MOD; for (int t = 0; t < 5; t++) res *= 2 - MOD * res; return res; } static int64 reduce(const int128 &v) { return (v + int128(int64(v) * int64(-INV)) * MOD) >> 64; } }; typename montgomery_modint::int64 montgomery_modint::MOD, montgomery_modint::INV, montgomery_modint::R; bool miller_rabin(long long m, const vector<long long> ps) { using mint = montgomery_modint; mint::set_mod(m); long long u = 0, v = m - 1; while ((v & 1) == 0) u++, v >>= 1; for (long long p : ps) { if (m <= p) return true; mint x = mint(p).pow(v); if (x != 1) { long long w; for (w = 0; w < u; w++) { if (x == m - 1) break; x *= x; } 
if (u == w) return false; } } return true; } bool miller_rabin_small(long long m) { return miller_rabin(m, {2, 7, 61}); } bool miller_rabin_large(long long m) { return miller_rabin(m, {2, 325, 9375, 28178, 450775, 9780504, 1795265022}); } bool is_prime(long long m) { if (m <= 1) return false; if (m == 2) return true; if (m % 2 == 0) return false; return m < 4759123141LL ? miller_rabin_small(m) : miller_rabin_large(m); } random_device seed; mt19937 rng(seed()); mt19937_64 rng_64(seed()); int randint(int low, int hi) { assert(low <= hi); uniform_int_distribution<int> dist(low, hi); return dist(rng); } long long randint_64(long long low, long long hi) { assert(low <= hi); uniform_int_distribution<long long> dist(low, hi); return dist(rng_64); } double randdouble(double low, double hi) { uniform_real_distribution<double> dist(low, hi); return dist(rng); } pair<int, int> randpair(int low, int hi, bool strict = false) { assert(low + strict <= hi); int L = randint(low, hi - strict); int R = randint(L + strict, hi); return make_pair(L, R); } pair<long long, long long> randpair_64(long long low, long long hi, bool strict = false) { assert(low + strict <= hi); long long L = randint_64(low, hi - strict); long long R = randint_64(L + strict, hi); return make_pair(L, R); } template<typename T> T rho(T n) { for (int p : {2, 3, 5, 7}) { if (n % p == 0) return p; } using mint = montgomery_modint; mint::set_mod(n); while (true) { mint u = randint_64(2, n - 1); mint v = u; mint c = randint_64(1, n - 1); T d = 1; while (d == 1) { u = u * u + c; v = v * v + c; v = v * v + c; d = gcd((u - v).val(), n); } if (d < n) return d; } return -1; } template<typename T> vector<T> prime_factor(T n) { if (n <= 1) return {}; if (is_prime(n)) return {n}; vector<T> res; T d = rho(n); auto a = prime_factor(d); auto b = prime_factor(n / d); merge(a.begin(), a.end(), b.begin(), b.end(), back_inserter(res)); res.erase(unique(res.begin(), res.end()), res.end()); return res; } long long primitive_root(long 
long m) { if (m == 2) return 1; if (m == 167772161) return 3; if (m == 469762049) return 3; if (m == 754974721) return 11; if (m == 998244353) return 3; if (m == 1224736769) return 3; auto ps = prime_factor(m - 1); using mint = montgomery_modint; mint::set_mod(m); mint a = randint_64(1, m - 1); while ([&]{ for (auto p : ps) { if (a.pow((m - 1) / p) == 1) return true; } return false; }()) a = randint_64(1, m - 1); return a.val(); } template<typename mint> struct Number_Theoretic_Transform { static vector<mint> dw, dw_inv; static int log; static mint root; static void ntt(vector<mint> &f) { init(); int n = f.size(); for (int m = n; m >>= 1;) { mint w = 1; for (int s = 0, k = 0; s < n; s += (m << 1)) { for (int i = s, j = s + m; i < s + m; i++, j++) { mint a = f[i], b = f[j] * w; f[i] = a + b; f[j] = a - b; } w *= dw[__builtin_ctz(++k)]; } } } static void intt(vector<mint> &f) { init(); int n = f.size(); for (int m = 1; m < n; m <<= 1) { mint w = 1; for (int s = 0, k = 0; s < n; s += (m << 1)) { for (int i = s, j = s + m; i < s + m; i++, j++) { mint a = f[i], b = f[j]; f[i] = a + b; f[j] = (a - b) * w; } w *= dw_inv[__builtin_ctz(++k)]; } } mint invn = mint(n).inv(); for (mint &x : f) x *= invn; } private: Number_Theoretic_Transform() = default; static void init() { if (log > 0) return; int mod = mint::mod(); root = primitive_root(mod); int tmp = mod - 1; log = 1; while (tmp % 2 == 0) { tmp >>= 1; log++; } dw.resize(log); dw_inv.resize(log); for (int i = 0; i < log; i++) { dw[i] = -root.pow((mod - 1) >> (i + 2)); dw_inv[i] = dw[i].inv(); } } }; template<typename mint> vector<mint>Number_Theoretic_Transform<mint>::dw = vector<mint>(); template<typename mint> vector<mint>Number_Theoretic_Transform<mint>::dw_inv = vector<mint>(); template<typename mint> int Number_Theoretic_Transform<mint>::log = 0; template<typename mint> mint Number_Theoretic_Transform<mint>::root = -1; template<typename mint> struct Formal_Power_Series : vector<mint> { using FPS = Formal_Power_Series; 
using vector<mint>::vector; using NTT = Number_Theoretic_Transform<mint>; void ntt() { NTT::ntt(*this); } void intt() { NTT::intt(*this); } FPS &operator+=(const mint &r) { if (this->empty()) this->resize(1); (*this)[0] += r; return *this; } FPS &operator-=(const mint &r) { if (this->empty()) this->resize(1); (*this)[0] -= r; return *this; } FPS &operator*=(const mint &r) { for (mint &x : *this) x *= r; return *this; } FPS &operator/=(const mint &r) { mint invr = r.inv(); return *this *= invr; } FPS &operator+=(const FPS &f) { int n = this->size(), m = f.size(); if (n < m) this->resize(m); for (int i = 0; i < m; i++) (*this)[i] += f[i]; return *this; } FPS &operator-=(const FPS &f) { int n = this->size(), m = f.size(); if (n < m) this->resize(m); for (int i = 0; i < m; i++) (*this)[i] -= f[i]; return *this; } FPS &operator*=(const FPS &f) { *this = convolution(*this, f); return *this; } FPS &operator/=(const FPS &f) { return *this *= f.inv(); } FPS &operator%=(const FPS &f) { *this -= div(f) * f; this->shrink(); return *this; } FPS div(const FPS &f) const { if (this->size() < f.size()) return FPS{}; int n = this->size() - f.size() + 1; return (rev().pre(n) * f.rev().inv(n)).pre(n).rev(n); } FPS operator+(const mint &r) const { return FPS(*this) += r; } FPS operator-(const mint &r) const { return FPS(*this) -= r; } FPS operator*(const mint &r) const { return FPS(*this) *= r; } FPS operator/(const mint &r) const { return FPS(*this) /= r; } FPS operator+(const FPS &f) const { return FPS(*this) += f; } FPS operator-(const FPS &f) const { return FPS(*this) -= f; } FPS operator*(const FPS &f) const { return FPS(*this) *= f; } FPS operator/(const FPS &f) const { return FPS(*this) /= f; } FPS operator%(const FPS &f) const { return FPS(*this) %= f; } FPS operator-() const { return FPS{} - *this; } FPS operator<<(int n) const { FPS res(*this); res.insert(res.begin(), n, mint()); return res; } FPS operator>>(int n) const { if (int(this->size()) <= n) return FPS{}; FPS 
res(*this); res.erase(res.begin(), res.begin() + n); return res; } FPS &operator<<=(int n) { return *this = (*this) << n; } FPS &operator>>=(int n) { return *this = (*this) >> n; } FPS pre(int n) const { n = min(n, int(this->size())); return FPS(this->begin(), this->begin() + n); } FPS rev(int deg = -1) const { FPS res(*this); if (deg != -1) res.resize(deg, 0); reverse(res.begin(), res.end()); return res; } FPS dot(const FPS &f) const { int n = min(this->size(), f.size()); FPS res(n); for (int i = 0; i < n; i++) res[i] = (*this)[i] * f[i]; return res; } void shrink() { while (this->size() && this->back() == 0) { this->pop_back(); } } mint operator()(const mint &x) const { mint res = 0, powx = 1; for (const mint &a : *this) { res += a * powx; powx *= x; } return res; } FPS diff() const { int n = this->size(); if (n == 0) return FPS{}; FPS res(n - 1); for (int i = 1; i < n; i++) { res[i - 1] = i * (*this)[i]; } return res; } FPS integral() const { int n = this->size(); FPS res(n + 1); res[0] = 0; for (int i = 0; i < n; i++) { res[i + 1] = (*this)[i] * combination::mod_inv<mint>(i + 1); } return res; } FPS inv(int deg = -1) const { int n = this->size(); assert(n > 0); mint c = (*this)[0]; assert(c != 0); if (deg == -1) deg = n; FPS res(deg); res[0] = c.inv(); for (int d = 1; d < deg; d <<= 1) { FPS f(d << 1), g(d << 1); for (int i = 0; i < n && i < d << 1; i++) f[i] = (*this)[i]; for (int i = 0; i < d; i++) g[i] = res[i]; f.ntt(); g.ntt(); f = f.dot(g); f.intt(); for (int i = 0; i < d; i++) f[i] = 0; f.ntt(); f = f.dot(g); f.intt(); for (int i = d; i < deg && i < d << 1; i++) res[i] -= f[i]; } return res; } FPS exp(int deg = -1) const { int n = this->size(); if (deg == -1) deg = n; if (n == 0) { FPS res(deg); res[0] = 1; return res; } assert((*this)[0] == 0); auto inplace_diff = [](FPS &f) -> void { if (f.empty()) return; f.erase(f.begin()); for (int i = 0; i < int(f.size()); i++) f[i] *= i + 1; }; auto inplace_integral = [&](FPS &f) -> void { f.insert(f.begin(), 0); 
for (int i = 1; i < int(f.size()); i++) f[i] *= combination::mod_inv<mint>(i); }; FPS b = {1, 1 < n ? (*this)[1] : 0}; FPS c = {1}, z1, z2 = {1, 1}; for (int d = 2; d < deg; d <<= 1) { FPS y = b; y.resize(d << 1); y.ntt(); z1 = z2; FPS z = y.dot(z1); z.intt(); fill(z.begin(), z.begin() + (d >> 1), 0); z.ntt(); z = z.dot(-z1); z.intt(); c.insert(c.end(), z.begin() + (d >> 1), z.end()); z2 = c; z2.resize(d << 1); z2.ntt(); FPS x(this->begin(), this->begin() + min(n, d)); inplace_diff(x); x.push_back(0); x.ntt(); x = x.dot(y); x.intt(); x -= b.diff(); x.resize(d << 1); for (int i = 0; i < d - 1; i++) x[i + d] = x[i], x[i] = 0; x.ntt(); x = x.dot(z2); x.intt(); x.pop_back(); inplace_integral(x); for (int i = d; i < min(n, d << 1); i++) x[i] += (*this)[i]; fill(x.begin(), x.begin() + d, 0); x.ntt(); x = x.dot(y); x.intt(); b.insert(b.end(), x.begin() + d, x.end()); } return FPS(b.begin(), b.begin() + deg); } FPS log(int deg = -1) const { assert((*this)[0] == 1); if (deg == -1) deg = this->size(); return (diff() * inv()).integral().pre(deg); } FPS pow(long long k, int deg = -1) const { if (deg == -1) deg = this->size(); if (k == 0) { FPS res(deg); res[0] = 1; return res; } FPS res(*this); int p = 0; while (p < int(res.size()) && res[p] == 0) p++; if (p > (deg - 1) / k) return FPS(deg); res >>= p; deg -= p * k; mint c = res[0]; res = ((res / c).log(deg) * k).exp(deg) * c.pow(k); res <<= p * k; return res; } FPS taylor_shift(mint c) const { int n = this->size(); FPS f = *this; for (int i = 0; i < n; i++) f[i] *= combination::fac<mint>(i); reverse(f.begin(), f.end()); FPS g(n); mint pw = 1; for (int i = 0; i < n; i++) { g[i] = pw * combination::finv<mint>(i); pw *= c; } f = convolution(f, g, n); reverse(f.begin(), f.end()); for (int i = 0; i < n; i++) f[i] *= combination::finv<mint>(i); return f; } private: static FPS convolution(FPS f, FPS g, int deg = -1) { int n = f.size(), m = g.size(); if (n == 0 || m == 0) return FPS{}; int sz = 1; while (sz < n + m - 1) sz <<= 1; 
f.resize(sz); f.ntt(); g.resize(sz); g.ntt(); f = f.dot(g); f.intt(); if (deg == -1) deg = n + m - 1; f.resize(deg); return f; } }; template<typename mint> vector<mint> power_sum(int k, long long n) { using FPS = Formal_Power_Series<mint>; vector<mint> fac(k + 2), finv(k + 2); fac[0] = 1; for (int i = 1; i <= k + 1; i++) { fac[i] = i * fac[i - 1]; } finv[k + 1] = fac[k + 1].inv(); for (int i = k + 1; i >= 1; i--) { finv[i - 1] = i * finv[i]; } mint pown = n + 1; FPS f(k + 1), g(k + 1); for (int i = 0; i <= k; i++) { f[i] = pown * finv[i + 1]; g[i] = finv[i + 1]; pown *= n + 1; } f /= g; vector<mint> res(k + 1); res[0] = n + 1; for (int i = 1; i <= k; i++) { res[i] = f[i] * fac[i]; } return res; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); int N, M; cin >> N >> M; auto P = power_sum<mint>(N, M); mint ans = 0; for (int i = 1; i <= N; i++) { ans += P[i] * mint(M).pow(N - i); } cout << ans.val() << endl; }