結果
問題 | No.802 だいたい等差数列 |
ユーザー |
![]() |
提出日時 | 2019-03-17 21:54:15 |
言語 | C++14 (gcc 13.3.0 + boost 1.87.0) |
結果 | WA |
実行時間 | - |
コード長 | 5,906 bytes |
コンパイル時間 | 1,459 ms |
コンパイル使用メモリ | 88,696 KB |
実行使用メモリ | 92,768 KB |
最終ジャッジ日時 | 2024-07-07 22:18:44 |
合計ジャッジ時間 | 5,415 ms |
ジャッジサーバーID (参考情報) | judge2 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | -- * 4 |
other | AC * 2 WA * 8 TLE * 1 -- * 19 |
ソースコード
// Solution for yukicoder No.802: builds the indicator polynomial of allowed
// step sizes, raises it to the (N-1)-th power (truncated to degree M-1) with
// NTT-based convolutions, and combines the coefficients with weights (M - S).
#ifndef CLASS_MODINT_INCLUDED
#define CLASS_MODINT_INCLUDED
#include <cstdint>
#include <vector>

using singlebit = uint32_t;
using doublebit = uint64_t;

// Returns n^{-1} mod 2^32 for odd n via Newton iteration: each step doubles the
// number of correct low bits (1 -> 2 -> 4 -> 8 -> 16 -> 32 with d = 5).
static constexpr singlebit find_inv(singlebit n, int d = 5, singlebit x = 1) {
    return d == 0 ? x : find_inv(n, d - 1, x * (2 - x * n));
}

// Fast modular integer stored in Montgomery form with R = 2^32.
// Requirements: `mod` is odd and mod < 2^31. `primroot` is assumed to be a
// primitive root of `mod` (the NTT below relies on that).
template <singlebit mod, singlebit primroot>
class modint {
private:
    singlebit n;  // Montgomery representation: value * 2^32 mod `mod`, in [0, mod)
    static constexpr int level = 32;  // bit width of singlebit
    // 2^64 mod `mod`; multiplying by it and reducing converts into Montgomery form.
    static constexpr singlebit r2 = (((1ull << level) % mod) << level) % mod;
    // -mod^{-1} mod 2^32, the Montgomery reduction constant.
    static constexpr singlebit inv = singlebit(-1) * find_inv(mod);

    // Montgomery reduction: returns x * 2^{-32} mod `mod`, valid for x < mod * 2^32.
    static singlebit reduce(doublebit x) {
        // singlebit(x) * inv wraps modulo 2^32 on purpose.
        singlebit res = (x + doublebit(singlebit(x) * inv) * mod) >> level;
        return res < mod ? res : res - mod;
    }

public:
    modint() : n(0) {}
    modint(singlebit n_) { n = reduce(doublebit(n_) * r2); }
    modint& operator=(const singlebit x) { n = reduce(doublebit(x) * r2); return *this; }
    bool operator==(const modint& x) const { return n == x.n; }
    bool operator!=(const modint& x) const { return n != x.n; }
    modint& operator+=(const modint& x) { n += x.n; n -= (n < mod ? 0 : mod); return *this; }
    modint& operator-=(const modint& x) { n += mod - x.n; n -= (n < mod ? 0 : mod); return *this; }
    modint& operator*=(const modint& x) { n = reduce(1ull * n * x.n); return *this; }
    modint operator+(const modint& x) const { return modint(*this) += x; }
    modint operator-(const modint& x) const { return modint(*this) -= x; }
    modint operator*(const modint& x) const { return modint(*this) *= x; }
    static constexpr singlebit get_mod() { return mod; }
    static constexpr singlebit get_primroot() { return primroot; }
    // Returns the ordinary (non-Montgomery) value in [0, mod).
    singlebit get() const { return reduce(doublebit(n)); }
    // Returns (*this)^b by square-and-multiply.
    modint binpow(singlebit b) const {
        modint ans(1), cur(*this);
        while (b > 0) {
            if (b & 1) ans *= cur;
            cur *= cur;
            b >>= 1;
        }
        return ans;
    }
};

// Converts a vector of non-negative ints into modular integers of type `modulo`.
template <typename modulo>
std::vector<modulo> get_modvector(const std::vector<int>& v) {
    std::vector<modulo> ans(v.size());
    for (std::size_t i = 0; i < v.size(); ++i) {
        ans[i] = v[i];
    }
    return ans;
}
#endif

#ifndef CLASS_NTT_INCLUDED
#define CLASS_NTT_INCLUDED
#include <cassert>
#include <utility>
#include <vector>

// Number Theoretic Transform over the prime field of `modulo`.
template <typename modulo>
class ntt {
private:
    int depth;                   // largest k such that 2^k divides mod - 1
    std::vector<modulo> roots;   // roots[k]: primitive 2^k-th root of unity (k <= depth)
    std::vector<modulo> powinv;  // powinv[k]: (2^k)^{-1} mod p

public:
    ntt() {
        depth = 0;
        uint32_t div_number = modulo::get_mod() - 1;
        while (div_number % 2 == 0) div_number >>= 1, ++depth;
        // b = g^(2^depth) has odd order; baseroot becomes g^(order of b).
        modulo b = modulo::get_primroot();
        for (int i = 0; i < depth; ++i) b *= b;
        modulo baseroot = modulo::get_primroot(), bb = b;
        while (bb != 1) bb *= b, baseroot *= modulo::get_primroot();
        roots = std::vector<modulo>(depth + 1, modulo(0));
        powinv = std::vector<modulo>(depth + 1, modulo(0));
        powinv[0] = 1;  // (2^0)^{-1}; previously left as 0
        powinv[1] = (modulo::get_mod() + 1) / 2;
        for (int i = 2; i <= depth; ++i) powinv[i] = powinv[i - 1] * powinv[1];
        // roots[depth] = baseroot^((mod - 1) / 2^depth).
        roots[depth] = 1;
        for (uint32_t i = 0; i < modulo::get_mod() - 1; i += 1u << depth) roots[depth] *= baseroot;
        // Squaring halves the order; roots[0] ends up as 1.
        for (int i = depth - 1; i >= 0; --i) roots[i] = roots[i + 1] * roots[i + 1];
    }

    // In-place forward (inverse == false) or inverse transform.
    // v.size() must be a power of two not exceeding 2^depth.
    void fourier_transform(std::vector<modulo>& v, bool inverse) {
        int s = v.size();
        // Bit-reversal permutation.
        for (int i = 0, j = 1; j < s - 1; ++j) {
            for (int k = s >> 1; k > (i ^= k); k >>= 1) {
            }
            if (i < j) std::swap(v[i], v[j]);
        }
        int sc = 0, sz = 1;
        while (sz < s) sz *= 2, ++sc;
        // roots[sc] only exists up to 2^depth; larger sizes cannot be transformed.
        assert(sz == s && sc <= depth);
        std::vector<modulo> pw(s + 1);
        pw[0] = 1;
        for (int i = 1; i <= s; i++) pw[i] = pw[i - 1] * roots[sc];
        int qs = s;
        for (int b = 1; b < s; b <<= 1) {
            qs >>= 1;
            for (int i = 0; i < s; i += b * 2) {
                for (int j = i; j < i + b; ++j) {
                    // Inverse transform uses the conjugate twiddle w^(-k) = w^(s - k).
                    modulo delta = pw[(inverse ? b * 2 - j + i : j - i) * qs] * v[j + b];
                    v[j + b] = v[j] - delta;
                    v[j] += delta;
                }
            }
        }
        if (inverse) {
            for (int i = 0; i < s; ++i) v[i] *= powinv[sc];
        }
    }

    // Returns the cyclic-free product of v1 and v2 modulo p (length padded to
    // s1 + s2, where s1 and s2 are the power-of-two sizes of the longer and
    // shorter input). Short operands use schoolbook multiplication. Otherwise
    // the longer one is split into chunks of the shorter one's padded size.
    std::vector<modulo> convolve(std::vector<modulo> v1, std::vector<modulo> v2) {
        const std::size_t threshold = 16;
        if (v1.size() < v2.size()) std::swap(v1, v2);
        std::size_t s1 = 1;
        while (s1 < v1.size()) s1 <<= 1;
        v1.resize(s1);
        std::size_t s2 = 1;
        while (s2 < v2.size()) s2 <<= 1;
        v2.resize(s2 * 2);
        std::vector<modulo> ans(s1 + s2);
        if (s2 <= threshold) {
            for (std::size_t i = 0; i < s1; ++i) {
                for (std::size_t j = 0; j < s2; ++j) {
                    ans[i + j] += v1[i] * v2[j];
                }
            }
        } else {
            fourier_transform(v2, false);
            for (std::size_t i = 0; i < s1; i += s2) {
                std::vector<modulo> v(v1.begin() + i, v1.begin() + i + s2);
                v.resize(s2 * 2);
                fourier_transform(v, false);
                for (std::size_t j = 0; j < v.size(); ++j) v[j] *= v2[j];
                fourier_transform(v, true);
                for (std::size_t j = 0; j < s2 * 2; ++j) {
                    ans[i + j] += v[j];
                }
            }
        }
        return ans;
    }
};
#endif

#include <algorithm>
#include <iostream>
#include <vector>
using namespace std;

// Three NTT-friendly primes (all with primitive root 3). Their product
// (~7.9e25) exceeds len * (1e9+6)^2 for any transform length the smallest
// prime supports (2^23). So every exact convolution coefficient of residues
// modulo 1e9+7 is recoverable by CRT. Two primes (~7.9e16) were not enough.
using modulo1 = modint<469762049, 3>;
using modulo2 = modint<167772161, 3>;
using modulo3 = modint<998244353, 3>;
ntt<modulo1> ntt_base1;
ntt<modulo2> ntt_base2;
ntt<modulo3> ntt_base3;

const int mod = 1000000007;

constexpr uint64_t P1 = modulo1::get_mod();
constexpr uint64_t P2 = modulo2::get_mod();
constexpr uint64_t P3 = modulo3::get_mod();
// Garner constants: P1^{-1} mod P2 and (P1 * P2)^{-1} mod P3.
const modulo2 inv_p1_mod_p2 = modulo2(static_cast<singlebit>(P1 % P2)).binpow(P2 - 2);
const modulo3 inv_p1p2_mod_p3 = modulo3(static_cast<singlebit>(P1 * P2 % P3)).binpow(P3 - 2);
const uint64_t p1p2_mod_answer = P1 * P2 % mod;

// Returns the convolution of v1 and v2 (entries assumed in [0, mod)), with
// each coefficient reduced modulo 1e9+7. The convolution runs over three NTT
// primes and the results are combined with Garner's algorithm.
vector<int> convolve_mod(vector<int> v1, vector<int> v2) {
    vector<modulo1> m1 = ntt_base1.convolve(get_modvector<modulo1>(v1), get_modvector<modulo1>(v2));
    vector<modulo2> m2 = ntt_base2.convolve(get_modvector<modulo2>(v1), get_modvector<modulo2>(v2));
    vector<modulo3> m3 = ntt_base3.convolve(get_modvector<modulo3>(v1), get_modvector<modulo3>(v2));
    vector<int> ans(m1.size());
    for (size_t i = 0; i < m1.size(); ++i) {
        const uint64_t r1 = m1[i].get();
        const uint64_t r2 = m2[i].get();
        const uint64_t r3 = m3[i].get();
        // x12 = r1 + P1 * t1 is the unique value in [0, P1*P2) congruent to r1, r2.
        const uint64_t t1 = ((modulo2(static_cast<singlebit>(r2)) -
                              modulo2(static_cast<singlebit>(r1 % P2))) * inv_p1_mod_p2).get();
        const uint64_t x12 = r1 + P1 * t1;
        // x = x12 + P1*P2 * t2 is the unique value in [0, P1*P2*P3) matching all three.
        const uint64_t t2 = ((modulo3(static_cast<singlebit>(r3)) -
                              modulo3(static_cast<singlebit>(x12 % P3))) * inv_p1p2_mod_p3).get();
        // Reduce x modulo 1e9+7 without forming it (it can exceed 64 bits).
        ans[i] = static_cast<int>((x12 % mod + p1p2_mod_answer * t2) % mod);
    }
    return ans;
}

int main() {
    int N, M, D1, D2;
    cin >> N >> M >> D1 >> D2;
    if (N < 1 || M < 1) {
        // Outside the range this algorithm handles: cur[0] would be out of
        // bounds for M < 1, and the exponent loop would not terminate for N < 1.
        cout << 0 << endl;
        return 0;
    }
    // cur = pw^(N-1) truncated to degree M-1: cur[S] counts the ways to pick
    // N-1 steps (each with pw[step] == 1) whose sum is S.
    vector<int> cur(M);
    cur[0] = 1;
    // pw: indicator of step sizes in [D1, D2]; only indices below M matter, so
    // the range is clamped (also avoids an O(D2) loop / overflow at INT_MAX).
    vector<int> pw(M);
    const int lo = max(D1, 0);
    const int hi = min(D2, M - 1);
    for (int d = lo; d <= hi; ++d) pw[d] = 1;
    int e = N - 1;
    while (e) {
        if (e & 1) {
            cur = convolve_mod(cur, pw);
            cur.resize(M);
        }
        e >>= 1;
        if (e) {  // skip the final squaring, whose result would be unused
            pw = convolve_mod(pw, pw);
            pw.resize(M);
        }
    }
    // Weight each total step sum S by (M - S), presumably the number of
    // admissible starting values in [1, M] — confirm against problem statement.
    int ans = 0;
    for (int i = 0; i < M; ++i) {
        ans = (ans + (long long)(cur[i]) * (M - i)) % mod;
    }
    cout << ans << endl;
    return 0;
}