結果

問題 No.802 だいたい等差数列
ユーザー square1001square1001
提出日時 2019-03-17 22:03:00
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
TLE  
実行時間 -
コード長 7,035 bytes
コンパイル時間 2,726 ms
コンパイル使用メモリ 103,780 KB
実行使用メモリ 4,380 KB
最終ジャッジ日時 2023-09-22 06:11:49
合計ジャッジ時間 8,734 ms
ジャッジサーバーID
(参考情報)
judge13 / judge14
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 348 ms
4,376 KB
testcase_01 AC 109 ms
4,376 KB
testcase_02 AC 394 ms
4,376 KB
testcase_03 AC 87 ms
4,376 KB
testcase_04 AC 185 ms
4,376 KB
testcase_05 AC 264 ms
4,376 KB
testcase_06 AC 183 ms
4,376 KB
testcase_07 AC 76 ms
4,376 KB
testcase_08 AC 328 ms
4,380 KB
testcase_09 AC 186 ms
4,376 KB
testcase_10 TLE -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
testcase_32 -- -
testcase_33 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

#ifndef ___CLASS_MODINT
#define ___CLASS_MODINT

#include <vector>
#include <cstdint>

using singlebit = uint32_t;
using doublebit = uint64_t;

static constexpr singlebit find_inv(singlebit n, int d = 5, singlebit x = 1) {
	return d == 0 ? x : find_inv(n, d - 1, x * (2 - x * n));
}
template <singlebit mod, singlebit primroot> class modint {
	// Fast Modulo Integer, Assertion: mod < 2^31
private:
	singlebit n;
	static constexpr int level = 32; // LIMIT OF singlebit
	static constexpr singlebit max_value = -1;
	static constexpr singlebit r2 = (((1ull << level) % mod) << level) % mod;
	static constexpr singlebit inv = singlebit(-1) * find_inv(mod);
	static singlebit reduce(doublebit x) {
		singlebit res = (x + doublebit(singlebit(x) * inv) * mod) >> level;
		return res < mod ? res : res - mod;
	}
public:
	modint() : n(0) {};
	modint(singlebit n_) { n = reduce(doublebit(n_) * r2); };
	modint& operator=(const singlebit x) { n = reduce(doublebit(x) * r2); return *this; }
	bool operator==(const modint& x) const { return n == x.n; }
	bool operator!=(const modint& x) const { return n != x.n; }
	modint& operator+=(const modint& x) { n += x.n; n -= (n < mod ? 0 : mod); return *this; }
	modint& operator-=(const modint& x) { n += mod - x.n; n -= (n < mod ? 0 : mod); return *this; }
	modint& operator*=(const modint& x) { n = reduce(1ull * n * x.n); return *this; }
	modint operator+(const modint& x) const { return modint(*this) += x; }
	modint operator-(const modint& x) const { return modint(*this) -= x; }
	modint operator*(const modint& x) const { return modint(*this) *= x; }
	static singlebit get_mod() { return mod; }
	static singlebit get_primroot() { return primroot; }
	singlebit get() { return reduce(doublebit(n)); }
	modint binpow(singlebit b) {
		modint ans(1), cur(*this);
		while (b > 0) {
			if (b & 1) ans *= cur;
			cur *= cur;
			b >>= 1;
		}
		return ans;
	}
};

template<typename modulo>
std::vector<modulo> get_modvector(std::vector<int> v) {
	std::vector<modulo> ans(v.size());
	for (int i = 0; i < v.size(); ++i) {
		ans[i] = v[i];
	}
	return ans;
}

#endif

#ifndef ___CLASS_NTT
#define ___CLASS_NTT

#include <vector>

template<typename modulo>
class ntt {
	// Number Theoretic Transform
private:
	int depth;
	std::vector<modulo> roots;
	std::vector<modulo> powinv;
public:
	ntt() {
		depth = 0;
		uint32_t div_number = modulo::get_mod() - 1;
		while (div_number % 2 == 0) div_number >>= 1, ++depth;
		modulo b = modulo::get_primroot();
		for (int i = 0; i < depth; ++i) b *= b;
		modulo baseroot = modulo::get_primroot(), bb = b;
		while (bb != 1) bb *= b, baseroot *= modulo::get_primroot();
		roots = std::vector<modulo>(depth + 1, 0);
		powinv = std::vector<modulo>(depth + 1, 0);
		powinv[1] = (modulo::get_mod() + 1) / 2;
		for (int i = 2; i <= depth; ++i) powinv[i] = powinv[i - 1] * powinv[1];
		roots[depth] = 1;
		for (int i = 0; i < modulo::get_mod() - 1; i += 1 << depth) roots[depth] *= baseroot;
		for (int i = depth - 1; i >= 1; --i) roots[i] = roots[i + 1] * roots[i + 1];
	}
	void fourier_transform(std::vector<modulo> &v, bool inverse) {
		int s = v.size();
		for (int i = 0, j = 1; j < s - 1; ++j) {
			for (int k = s >> 1; k >(i ^= k); k >>= 1);
			if (i < j) std::swap(v[i], v[j]);
		}
		int sc = 0, sz = 1;
		while (sz < s) sz *= 2, ++sc;
		std::vector<modulo> pw(s + 1); pw[0] = 1;
		for (int i = 1; i <= s; i++) pw[i] = pw[i - 1] * roots[sc];
		int qs = s;
		for (int b = 1; b < s; b <<= 1) {
			qs >>= 1;
			for (int i = 0; i < s; i += b * 2) {
				for (int j = i; j < i + b; ++j) {
					modulo delta = pw[(inverse ? b * 2 - j + i : j - i) * qs] * v[j + b];
					v[j + b] = v[j] - delta;
					v[j] += delta;
				}
			}
		}
		if (inverse) {
			for (int i = 0; i < s; ++i) v[i] *= powinv[sc];
		}
	}
	std::vector<modulo> convolve(std::vector<modulo> v1, std::vector<modulo> v2) {
		const int threshold = 16;
		if (v1.size() < v2.size()) swap(v1, v2);
		int s1 = 1; while (s1 < v1.size()) s1 <<= 1; v1.resize(s1);
		int s2 = 1; while (s2 < v2.size()) s2 <<= 1; v2.resize(s2 * 2);
		std::vector<modulo> ans(s1 + s2);
		if (s2 <= threshold) {
			for (int i = 0; i < s1; ++i) {
				for (int j = 0; j < s2; ++j) {
					ans[i + j] += v1[i] * v2[j];
				}
			}
		}
		else {
			fourier_transform(v2, false);
			for (int i = 0; i < s1; i += s2) {
				std::vector<modulo> v(v1.begin() + i, v1.begin() + i + s2);
				v.resize(s2 * 2);
				fourier_transform(v, false);
				for (int j = 0; j < v.size(); ++j) v[j] *= v2[j];
				fourier_transform(v, true);
				for (int j = 0; j < s2 * 2; ++j) {
					ans[i + j] += v[j];
				}
			}
		}
		return ans;
	}
};

#endif

#include <vector>
#include <iostream>
using namespace std;

using modulo1 = modint<469762049, 3>; ntt<modulo1> ntt_base1;
using modulo2 = modint<167772161, 3>; ntt<modulo2> ntt_base2;
using modulo3 = modint<998244353, 3>; ntt<modulo3> ntt_base3;

const modulo1 magic_inv = modulo1(modulo2::get_mod()).binpow(modulo1::get_mod() - 2);

const int mod = 1000000007;

int binpow(int a, int b, int p) {
	int ans = 1;
	while (b) {
		if (b & 1) ans = (long long)(ans) * a % p;
		a = (long long)(a) * a % p;
		b >>= 1;
	}
	return ans;
}

// Garner: Thanks to https://math314.hateblo.jp/entry/2015/05/07/014908
long long garner(vector<pair<int, int> > mr, int mod) {
	mr.emplace_back(mod, 0);

	vector<long long> coffs(mr.size(), 1);
	vector<long long> constants(mr.size(), 0);
	for (int i = 0; i < mr.size() - 1; ++i) {
		// coffs[i] * v + constants[i] == mr[i].second (mod mr[i].first) を解く
		long long v = (mr[i].second - constants[i]) * binpow(coffs[i] % mr[i].first, mr[i].first - 2, mr[i].first) % mr[i].first;
		if (v < 0) v += mr[i].first;

		for (int j = i + 1; j < mr.size(); j++) {
			(constants[j] += coffs[j] * v) %= mr[j].first;
			(coffs[j] *= mr[i].first) %= mr[j].first;
		}
	}

	return constants[mr.size() - 1];
}

vector<int> convolve_mod(vector<int> v1, vector<int> v2) {
	vector<modulo1> mul_base1 = ntt_base1.convolve(get_modvector<modulo1>(v1), get_modvector<modulo1>(v2));
	vector<modulo2> mul_base2 = ntt_base2.convolve(get_modvector<modulo2>(v1), get_modvector<modulo2>(v2));
	vector<modulo3> mul_base3 = ntt_base3.convolve(get_modvector<modulo3>(v1), get_modvector<modulo3>(v2));
	vector<int> ans(mul_base1.size());
	for (int i = 0; i < mul_base1.size(); ++i) {
		vector<pair<int, int> > vec = {
			make_pair(modulo1::get_mod(), mul_base1[i].get()),
			make_pair(modulo2::get_mod(), mul_base2[i].get()),
			make_pair(modulo3::get_mod(), mul_base3[i].get())
		};
		long long val = garner(vec, mod);
		ans[i] = val % mod;
	}
	return ans;
}

int main() {
	int N, M, D1, D2;
	cin >> N >> M >> D1 >> D2;
	vector<int> cur(M); cur[0] = 1;
	vector<int> pw(M);
	for (int i = D1; i <= D2; ++i) {
		if (0 <= i && i < M) pw[i] = 1;
	}
	--N;
	while (N) {
		if (N & 1) {
			cur = convolve_mod(cur, pw);
			cur.resize(M);
		}
		pw = convolve_mod(pw, pw);
		pw.resize(M);
		N >>= 1;
	}
	int ans = 0;
	for (int i = 0; i < M; ++i) {
		ans = (ans + (long long)(cur[i]) * (M - i)) % mod;
	}
	cout << ans << endl;
	return 0;
}
0