結果
問題 | No.2670 Sum of Products of Interval Lengths |
ユーザー | Misuki |
提出日時 | 2024-03-21 22:52:59 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 172 ms / 2,000 ms |
コード長 | 13,483 bytes |
コンパイル時間 | 3,193 ms |
コンパイル使用メモリ | 214,268 KB |
実行使用メモリ | 22,804 KB |
最終ジャッジ日時 | 2024-09-30 10:23:22 |
合計ジャッジ時間 | 5,511 ms |
ジャッジサーバーID (参考情報) | judge5 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 17 |
ソースコード
// Solution to yukicoder No.2670 "Sum of Products of Interval Lengths".
// Layers: Montgomery modular integers -> NTT convolution -> formal power
// series (FPS) toolkit -> main, which prints [x^n] 1 / (1 - g(x)) for a
// polynomial g built from the input (see main).
#pragma GCC optimize("O2")
#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <cctype>
#include <cfenv>
#include <cfloat>
#include <chrono>
#include <cinttypes>
#include <climits>
#include <cmath>
#include <complex>
#include <cstdarg>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <deque>
#include <fstream>
#include <functional>
#include <initializer_list>
#include <iomanip>
#include <ios>
#include <iostream>
#include <istream>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <new>
#include <numeric>
#include <ostream>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <stack>
#include <streambuf>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
#if __cplusplus >= 202002L
#include <bit>
#include <compare>
#include <concepts>
#include <numbers>
#include <ranges>
#include <span>
#else
// Pre-C++20 stand-ins for std::ssize and the <bit> helpers used below.
#define ssize(v) (int)(v).size()
#define popcount(x) __builtin_popcountll(x)
constexpr int bit_width(const unsigned int x) { return x == 0 ? 0 : ((sizeof(unsigned int) * CHAR_BIT) - __builtin_clz(x)); }
constexpr int bit_width(const unsigned long long x) { return x == 0 ? 0 : ((sizeof(unsigned long long) * CHAR_BIT) - __builtin_clzll(x)); }
constexpr int countr_zero(const unsigned int x) { return x == 0 ? sizeof(unsigned int) * CHAR_BIT : __builtin_ctz(x); }
constexpr int countr_zero(const unsigned long long x) { return x == 0 ? sizeof(unsigned long long) * CHAR_BIT : __builtin_ctzll(x); }
constexpr unsigned int bit_ceil(const unsigned int x) { return x == 0 ? 1 : (popcount(x) == 1 ? x : (1u << bit_width(x))); }
constexpr unsigned long long bit_ceil(const unsigned long long x) { return x == 0 ? 1 : (popcount(x) == 1 ? x : (1ull << bit_width(x))); }
#endif

//#define int ll
#define INT128_MAX (__int128)(((unsigned __int128) 1 << ((sizeof(__int128) * __CHAR_BIT__) - 1)) - 1)
#define INT128_MIN (-INT128_MAX - 1)
#define clock chrono::steady_clock::now().time_since_epoch().count()

#ifdef DEBUG
#define dbg(x) cout << (#x) << " = " << x << '\n'
#else
#define dbg(x)
#endif

using namespace std;
using ll = long long;
using ull = unsigned long long;
using ldb = long double;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
//#define double ldb

// Debug printers; elements are written separated by single spaces.
// The pair printer accepts different first/second types (e.g. pair<int, ll>).
template<class T, class U>
ostream& operator<<(ostream& os, const pair<T, U>& pr) {
    return os << pr.first << ' ' << pr.second;
}
template<class T, size_t N>
ostream& operator<<(ostream& os, const array<T, N> &arr) {
    for(const T &X : arr)
        os << X << ' ';
    return os;
}
template<class T>
ostream& operator<<(ostream& os, const vector<T> &vec) {
    for(const T &X : vec)
        os << X << ' ';
    return os;
}
template<class T>
ostream& operator<<(ostream& os, const set<T> &s) {
    for(const T &x : s)
        os << x << ' ';
    return os;
}

//reference: https://github.com/NyaanNyaan/library/blob/master/modint/montgomery-modint.hpp#L10
//note: mod should be a prime less than 2^30.
//
// Modular integer in Montgomery form. The member `a` stores x * 2^32 mod `mod`:
// transform() multiplies by n2 = 2^64 % mod and reduce() divides by 2^32.
// `a` is reduced lazily and may lie anywhere in [0, 2 * mod); get(), == and !=
// normalize it before use.
template<uint32_t mod>
struct MontgomeryModInt {
    // Enforce the documented precondition (odd modulus below 2^30) at compile time.
    static_assert(mod % 2 == 1 && mod < (1u << 30), "MontgomeryModInt needs an odd modulus below 2^30");

    using mint = MontgomeryModInt;
    using i32 = int32_t;
    using u32 = uint32_t;
    using u64 = uint64_t;

    // Computes -mod^{-1} mod 2^32: mod^(2^31 - 1) is the inverse of odd mod modulo 2^32.
    static constexpr u32 get_r() {
        u32 res = 1, base = mod;
        for(i32 i = 0; i < 31; i++)
            res *= base, base *= base;
        return -res;
    }

    static constexpr u32 get_mod() { return mod; }

    static constexpr u32 n2 = -u64(mod) % mod; //2^64 % mod
    static constexpr u32 r = get_r(); //-P^{-1} % 2^32

    u32 a;

    // Montgomery reduction: returns b * 2^-32 mod `mod` (not fully reduced).
    static u32 reduce(const u64 &b) { return (b + u64(u32(b) * r) * mod) >> 32; }
    // Converts a plain residue into Montgomery form.
    static u32 transform(const u64 &b) { return reduce(u64(b) * n2); }

    MontgomeryModInt() : a(0) {}
    // Implicit on purpose: lets integer literals mix freely with mint arithmetic.
    MontgomeryModInt(const int64_t &b) : a(transform(b % mod + mod)) {}

    // Binary exponentiation: (*this)^k.
    mint pow(u64 k) const {
        mint res(1), base(*this);
        while(k) {
            if (k & 1) res *= base;
            base *= base, k >>= 1;
        }
        return res;
    }
    // Inverse via Fermat's little theorem (mod is prime); 0 maps to 0.
    mint inverse() const { return (*this).pow(mod - 2); }
    // Canonical residue in [0, mod).
    u32 get() const {
        u32 res = reduce(a);
        return res >= mod ? res - mod : res;
    }

    mint& operator+=(const mint &b) {
        if (i32(a += b.a - 2 * mod) < 0) a += 2 * mod;
        return *this;
    }
    mint& operator-=(const mint &b) {
        if (i32(a -= b.a) < 0) a += 2 * mod;
        return *this;
    }
    mint& operator*=(const mint &b) {
        a = reduce(u64(a) * b.a);
        return *this;
    }
    mint& operator/=(const mint &b) {
        a = reduce(u64(a) * b.inverse().a);
        return *this;
    }
    // Unary minus; const so it also applies to const operands.
    mint operator-() const { return mint() - *this; }
    bool operator==(mint b) const {
        return (a >= mod ? a - mod : a) == (b.a >= mod ? b.a - mod : b.a);
    }
    bool operator!=(mint b) const {
        return (a >= mod ? a - mod : a) != (b.a >= mod ? b.a - mod : b.a);
    }
    friend mint operator+(mint a, mint b) { return a += b; }
    friend mint operator-(mint a, mint b) { return a -= b; }
    friend mint operator*(mint a, mint b) { return a *= b; }
    friend mint operator/(mint a, mint b) { return a /= b; }
    friend ostream& operator<<(ostream& os, const mint& b) { return os << b.get(); }
    friend istream& operator>>(istream& is, mint& b) {
        int64_t val;
        is >> val;
        b = mint(val);
        return is;
    }
};

using mint = MontgomeryModInt<998244353>;

//reference: https://judge.yosupo.jp/submission/69896
//remark: MOD = 2^K * C + 1, R is a primitive root modulo MOD
//remark: a.size() <= 2^K must be satisfied
//some common modulo: 998244353 = 2^23 * 119 + 1, R = 3
// 469762049 = 2^26 * 7 + 1, R = 3
// 1224736769 = 2^24 * 73 + 1, R = 3
//
// Note: the template parameter r is currently unused. ntt() searches for a
// quadratic non-residue at run time and derives its roots of unity from it.
template<int32_t k = 23, int32_t c = 119, int32_t r = 3, class Mint = MontgomeryModInt<998244353>>
struct NTT {
    using u32 = uint32_t;
    static constexpr u32 mod = (1 << k) * c + 1;
    // The transform is only valid when (k, c) describe the modulus Mint works in.
    static_assert(mod == Mint::get_mod(), "NTT parameters k and c must describe the modulus of Mint");
    static constexpr u32 get_mod() { return mod; }

    // In-place transform of `a`, whose size must be a power of two and at most 2^k.
    // Forward: butterflies with stride n/2, n/4, ..., 1. Inverse: strides
    // 1, 2, ..., n/2, then scaling by 1/n. No bit-reversal permutation is
    // applied, so the forward output is only meant for pointwise products
    // followed by the matching inverse call (as conv does).
    static void ntt(vector<Mint> &a, bool inverse) {
        // Twiddle tables are built on the first call (w[0] is zero only before then).
        static array<Mint, 30> w, w_inv;
        if (w[0] == 0) {
            Mint root = 2;
            while(root.pow((mod - 1) / 2) == 1) root += 1;
            for(int i = 0; i < 30; i++)
                w[i] = -(root.pow((mod - 1) >> (i + 2))), w_inv[i] = 1 / w[i];
        }
        int n = ssize(a);
        if (not inverse) {
            for(int m = n; m >>= 1; ) {
                Mint ww = 1;
                for(int s = 0, l = 0; s < n; s += 2 * m) {
                    for(int i = s, j = s + m; i < s + m; i++, j++) {
                        Mint x = a[i], y = a[j] * ww;
                        a[i] = x + y, a[j] = x - y;
                    }
                    ww *= w[__builtin_ctz(++l)];
                }
            }
        } else {
            for(int m = 1; m < n; m *= 2) {
                Mint ww = 1;
                for(int s = 0, l = 0; s < n; s += 2 * m) {
                    for(int i = s, j = s + m; i < s + m; i++, j++) {
                        Mint x = a[i], y = a[j];
                        a[i] = x + y, a[j] = (x - y) * ww;
                    }
                    ww *= w_inv[__builtin_ctz(++l)];
                }
            }
            Mint inv = 1 / Mint(n);
            for(Mint &x : a) x *= inv;
        }
    }

    // Linear convolution of a and b; the result has |a| + |b| - 1 terms.
    // An empty operand stands for the zero polynomial, so the result is empty.
    static vector<Mint> conv(vector<Mint> a, vector<Mint> b) {
        if (a.empty() or b.empty()) return {};
        int sz = ssize(a) + ssize(b) - 1;
        int n = bit_ceil((u32)sz);
        assert(n <= (1 << k)); // transform length must divide mod - 1
        a.resize(n, 0);
        ntt(a, false);
        b.resize(n, 0);
        ntt(b, false);
        for(int i = 0; i < n; i++)
            a[i] *= b[i];
        ntt(a, true);
        a.resize(sz);
        return a;
    }
};

//#include "modint/MontgomeryModInt.cpp"
//#include "poly/NTTmint.cpp"
//lagrange inversion formula:
// let f(x) be composition inverse of g(x) (i.e. f(g(x)) = x) and [x^0]f(x) = [x^0]g(x) = 0, [x^1]f(x) != 0, [x^1]g(x) != 0, then
// [x^n]g(x)^k = k/n [x^{n - k}] (x / f(x))^n
// [x^n]g(x) = 1/n [x^{n - 1}] (x / f(x))^n (for k = 1)
//
// Formal power series over Mint: element i is the coefficient of x^i, and an
// empty vector is the zero series. Products go through the static `conv` hook,
// which must be assigned before any multiplication.
template<class Mint>
struct FPS : vector<Mint> {
    static function<vector<Mint>(vector<Mint>, vector<Mint>)> conv;

    FPS() = default;
    FPS(vector<Mint> v) : vector<Mint>(std::move(v)) {}
    using vector<Mint>::vector;

    FPS& operator+=(const FPS& b) {
        if (ssize(*this) < ssize(b)) this -> resize(ssize(b), 0);
        for(int i = 0; i < ssize(b); i++)
            (*this)[i] += b[i];
        return *this;
    }
    FPS& operator-=(const FPS& b) {
        if (ssize(*this) < ssize(b)) this -> resize(ssize(b), 0);
        for(int i = 0; i < ssize(b); i++)
            (*this)[i] -= b[i];
        return *this;
    }
    // Product with another series. The result has |*this| + |b| - 1 terms, or
    // is empty when either factor is the zero (empty) series.
    FPS& operator*=(const FPS& b) {
        if (this -> empty() or b.empty()) {
            this -> clear();
            return *this;
        }
        vector<Mint> c = conv(*this, b);
        this -> swap(c);
        return *this;
    }
    FPS& operator*=(Mint b) {
        for(int i = 0; i < ssize(*this); i++)
            (*this)[i] *= b;
        return *this;
    }
    FPS& operator/=(Mint b) {
        b = Mint(1) / b;
        for(int i = 0; i < ssize(*this); i++)
            (*this)[i] *= b;
        return *this;
    }

    // Copy with trailing zero coefficients removed (the zero series becomes empty).
    FPS shrink() const {
        FPS F = *this;
        int len = ssize(F);
        while(len and F[len - 1] == 0) len -= 1;
        F.resize(len);
        return F;
    }

    // Antiderivative with zero constant term (one term longer than *this).
    // Inverses come from inv[i] = -(mod / i) * inv[mod % i].
    FPS integral() const {
        if (this -> empty()) return {0};
        vector<Mint> Inv(ssize(*this) + 1);
        Inv[1] = 1;
        for(int i = 2; i < ssize(Inv); i++)
            Inv[i] = (Mint::get_mod() - Mint::get_mod() / i) * Inv[Mint::get_mod() % i];
        FPS Q(ssize(*this) + 1, 0);
        for(int i = 0; i < ssize(*this); i++)
            Q[i + 1] = (*this)[i] * Inv[i + 1];
        return Q;
    }

    // Formal derivative (one term shorter); requires a non-empty series.
    FPS derivative() const {
        assert(!this -> empty());
        FPS Q(ssize(*this) - 1);
        for(int i = 1; i < ssize(*this); i++)
            Q[i - 1] = (*this)[i] * i;
        return Q;
    }

    // Evaluates the series (as a polynomial) at x by accumulating powers of x.
    Mint eval(Mint x) const {
        Mint base = 1, res = 0;
        for(int i = 0; i < ssize(*this); i++, base *= x)
            res += (*this)[i] * base;
        return res;
    }

    // 1 / FPS (mod x^k) by Newton iteration Q <- Q (2 - F Q), doubling the
    // precision each step. Requires a non-zero constant term.
    FPS inv(int k) const {
        assert(!this -> empty() and (*this)[0] != 0);
        FPS Q(1, 1 / (*this)[0]);
        for(int i = 1; (1 << (i - 1)) < k; i++) {
            FPS P = (*this);
            P.resize(1 << i, 0);
            Q = Q * (FPS(1, 2) - P * Q);
            Q.resize(1 << i, 0);
        }
        Q.resize(k);
        return Q;
    }

    // Polynomial division: returns {quotient, remainder} with
    // F = G * quotient + remainder, using the reversed-series trick. Trailing
    // zeros are stripped from both operands; G must be non-zero.
    array<FPS, 2> div(FPS G) const {
        FPS F = this -> shrink();
        G = G.shrink();
        assert(!G.empty());
        if (ssize(G) > ssize(F))
            return {{{}, F}};
        int n = ssize(F) - ssize(G) + 1;
        auto FR = F, GR = G;
        std::reverse(FR.begin(), FR.end());
        std::reverse(GR.begin(), GR.end());
        FPS Q = FR * GR.inv(n);
        Q.resize(n);
        std::reverse(Q.begin(), Q.end());
        return {Q, (F - G * Q).shrink()};
    }

    // log F (mod x^k) computed as the integral of F' / F; requires F[0] == 1.
    // Returns the empty series when k <= 0.
    FPS log(int k) const {
        assert(!this -> empty() and (*this)[0] == 1);
        if (k <= 0) return {};
        FPS Q = *this;
        Q = (Q.derivative() * Q.inv(k));
        Q.resize(k - 1);
        return Q.integral();
    }

    // exp F (mod x^k) by Newton iteration Q <- Q (1 + F - log Q); requires F[0] == 0.
    FPS exp(int k) const {
        assert(!this -> empty() and (*this)[0] == 0);
        FPS Q(1, 1);
        for(int i = 1; (1 << (i - 1)) < k; i++) {
            FPS P = (*this);
            P.resize(1 << i, 0);
            Q = Q * (FPS(1, 1) + P - Q.log(1 << i));
            Q.resize(1 << i, 0);
        }
        Q.resize(k);
        return Q;
    }

    // F^idx (mod x^k) for idx >= 0. Factors out the lowest non-zero term
    // c * x^i, raises the rest via exp(idx * log(.)), then shifts by i * idx
    // and scales by c^idx.
    FPS pow(ll idx, int k) const {
        if (idx == 0) {
            FPS res(k, 0);
            if (k > 0) res[0] = 1;
            return res;
        }
        for(int i = 0; i < ssize(*this) and i * idx < k; i++) {
            if ((*this)[i] != 0) {
                Mint Inv = 1 / (*this)[i];
                FPS Q(ssize(*this) - i);
                for(int j = i; j < ssize(*this); j++)
                    Q[j - i] = (*this)[j] * Inv;
                Q = (Q.log(k) * idx).exp(k);
                FPS Q2(k, 0);
                Mint Pow = (*this)[i].pow(idx);
                for(int j = 0; j + i * idx < k; j++)
                    Q2[j + i * idx] = Q[j] * Pow;
                return Q2;
            }
        }
        return FPS(k, 0);
    }

    // Values F(xs[0]), ..., F(xs[n-1]) via a product tree of (x - xs[i]) and
    // successive remainders down that tree. Empty input gives empty output.
    vector<Mint> multieval(vector<Mint> xs) const {
        int n = ssize(xs);
        if (n == 0) return {};
        vector<FPS> data(2 * n);
        for(int i = 0; i < n; i++)
            data[n + i] = {-xs[i], 1};
        for(int i = n - 1; i > 0; i--)
            data[i] = data[i << 1] * data[i << 1 | 1];
        data[1] = (this -> div(data[1]))[1];
        for(int i = 1; i < n; i++) {
            data[i << 1] = data[i].div(data[i << 1])[1];
            data[i << 1 | 1] = data[i].div(data[i << 1 | 1])[1];
        }
        vector<Mint> res(n);
        for(int i = 0; i < n; i++)
            res[i] = data[n + i].empty() ? 0 : data[n + i][0];
        return res;
    }

    // Polynomial of degree < n through the points (xs[i], ys[i]).
    // xs must be pairwise distinct. Empty input gives the empty (zero) series.
    static vector<Mint> interpolate(vector<Mint> xs, vector<Mint> ys) {
        assert(ssize(xs) == ssize(ys));
        int n = ssize(xs);
        if (n == 0) return {};
        vector<FPS> data(2 * n), res(2 * n);
        for(int i = 0; i < n; i++)
            data[n + i] = {-xs[i], 1};
        for(int i = n - 1; i > 0; i--)
            data[i] = data[i << 1] * data[i << 1 | 1];
        res[1] = data[1].derivative().div(data[1])[1];
        for(int i = 1; i < n; i++) {
            res[i << 1] = res[i].div(data[i << 1])[1];
            res[i << 1 | 1] = res[i].div(data[i << 1 | 1])[1];
        }
        for(int i = 0; i < n; i++) {
            // An empty leaf means P'(xs[i]) == 0, i.e. duplicated x values.
            assert(!res[n + i].empty());
            res[n + i][0] = ys[i] / res[n + i][0];
        }
        for(int i = n - 1; i > 0; i--)
            res[i] = res[i << 1] * data[i << 1 | 1] + res[i << 1 | 1] * data[i << 1];
        return res[1];
    }

    // Product of all series in fs by divide and conquer; {1} for an empty list.
    static vector<Mint> allProd(const vector<FPS> &fs) {
        if (fs.empty()) return {1};
        auto dfs = [&](int l, int r, auto self) -> FPS {
            if (l + 1 == r)
                return fs[l];
            return self(l, (l + r) / 2, self) * self((l + r) / 2, r, self);
        };
        return dfs(0, ssize(fs), dfs);
    }

    friend FPS operator+(FPS a, FPS b) { return a += b; }
    friend FPS operator-(FPS a, FPS b) { return a -= b; }
    friend FPS operator*(FPS a, FPS b) { return a *= b; }
    friend FPS operator*(FPS a, Mint b) { return a *= b; }
    friend FPS operator/(FPS a, Mint b) { return a /= b; }
};

NTT<> ntt;
using fps = FPS<mint>;
template<>
function<vector<mint>(vector<mint>, vector<mint>)> fps::conv = ntt.conv;

// Reads n and m, then prints [x^n] 1 / (1 - g(x)) modulo 998244353, where
//   f(x) = x / (1 - x + x^2)   (coefficients 0, 1, 1, 0, -1, -1, repeating),
//   g_i  = f_i * max(m + 1 - i, 0).
signed main() {
    ios::sync_with_stdio(false), cin.tie(nullptr);

    int n; cin >> n;
    ll m; cin >> m;

    fps f(n + 1);
    // f has n + 1 terms, so f[1] only exists when n >= 1.
    if (n >= 1) f[1] = 1;
    for(int i = 2; i <= n; i++)
        f[i] = -f[i - 2] + f[i - 1];
    for(int i = 0; i <= n; i++)
        f[i] *= max(m + 1 - i, 0ll);
    // Build 1 - g(x); g has no constant term because f[0] == 0.
    f *= -1;
    f[0] = 1;
    cout << f.inv(n + 1).back() << '\n';
    return 0;
}