結果

問題 No.1970 ひよこ鑑定士
ユーザー 57tggx
提出日時 2022-05-31 10:15:04
言語 C++23 (gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 702 ms / 2,000 ms
コード長 5,863 bytes
コンパイル時間 8,609 ms
コンパイル使用メモリ 276,004 KB
実行使用メモリ 13,784 KB
最終ジャッジ日時 2024-09-21 01:11:22
合計ジャッジ時間 14,299 ms
ジャッジサーバーID (参考情報) judge5 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 2 ms 6,812 KB
testcase_01 AC 2 ms 6,812 KB
testcase_02 AC 3 ms 6,940 KB
testcase_03 AC 2 ms 6,940 KB
testcase_04 AC 2 ms 6,940 KB
testcase_05 AC 2 ms 6,940 KB
testcase_06 AC 2 ms 6,940 KB
testcase_07 AC 16 ms 6,944 KB
testcase_08 AC 678 ms 13,784 KB
testcase_09 AC 245 ms 7,616 KB
testcase_10 AC 244 ms 7,700 KB
testcase_11 AC 51 ms 6,944 KB
testcase_12 AC 229 ms 7,672 KB
testcase_13 AC 478 ms 10,864 KB
testcase_14 AC 548 ms 11,292 KB
testcase_15 AC 13 ms 6,940 KB
testcase_16 AC 50 ms 6,940 KB
testcase_17 AC 471 ms 10,712 KB
testcase_18 AC 476 ms 10,624 KB
testcase_19 AC 523 ms 10,972 KB
testcase_20 AC 697 ms 13,784 KB
testcase_21 AC 702 ms 13,780 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>
#include "testlib.h"

// Number of bits needed to represent x (position of the highest set bit plus one); 0 for x == 0.
unsigned bit_width(unsigned long long x){
	if(x == 0) return 0;  // __builtin_clzll(0) is undefined, so handle zero separately
	const unsigned total_bits = 8 * sizeof x;
	return total_bits - static_cast<unsigned>(__builtin_clzll(x));
}

/// Integer modulo MOD, always stored as its canonical representative in [0, MOD).
///
/// Conversions to and from std::uint32_t are implicit so plain integers mix into
/// ModInt arithmetic.  inv() and pow() are declared here and defined out of class
/// later in this file.
template <std::uint32_t MOD>
class ModInt {
	std::uint32_t value;
	// Reads an unsigned integer and reduces it modulo MOD.
	friend std::istream &operator>>(std::istream &is, ModInt &x){ is >> x.value; x.value %= MOD; return is; }
public:
	/// Constructs x mod MOD.  The parameter is 32-bit, so wider integer arguments are truncated before reduction.
	constexpr ModInt(std::uint32_t x = 0): value(x % MOD) {}
	/// Canonical representative in [0, MOD).  const-qualified so const ModInt objects can be read as well.
	constexpr operator std::uint32_t() const { return value; }
	constexpr ModInt operator-() const { return value == 0 ? 0 : MOD - value; }
	// Compare against MOD - other.value instead of forming value + other.value,
	// which could overflow 32 bits when MOD is close to 2^32.
	constexpr ModInt &operator+=(const ModInt &other){ if(value < MOD - other.value) value += other.value; else value -= MOD - other.value; return *this; }
	constexpr ModInt &operator-=(const ModInt &other){ if(value >= other.value) value -= other.value; else value += MOD - other.value; return *this; }
	// Widen to 64 bits so the product of two residues cannot overflow before reduction.
	constexpr ModInt &operator*=(const ModInt &other){ value = static_cast<std::uint64_t>(value) * other.value % MOD; return *this; }
	/// Division as multiplication by other.inv(); other must be invertible modulo MOD.
	constexpr ModInt &operator/=(const ModInt &other){ return *this *= other.inv(); }
	friend constexpr ModInt operator+(ModInt left, const ModInt &right){ return left += right; }
	friend constexpr ModInt operator-(ModInt left, const ModInt &right){ return left -= right; }
	friend constexpr ModInt operator*(ModInt left, const ModInt &right){ return left *= right; }
	friend constexpr ModInt operator/(ModInt left, const ModInt &right){ return left /= right; }
	/// Multiplicative inverse (extended Euclid); meaningful only when value is coprime to MOD.
	constexpr ModInt inv() const;
	/// *this raised to the given power by binary exponentiation.
	constexpr ModInt pow(unsigned) const;
};

/// In-place radix-2 discrete Fourier transform (self-sorting, ping-pongs between x and a scratch buffer).
///
/// zetas[k] must hold w^k for a primitive x.size()-th root of unity w, and x.size()
/// is expected to be a power of two.  The forward transform yields
/// X[k] = sum_j x[j] * w^(j*k) in natural order.  With inverse == true, w^-k is read
/// as zetas[size - k], and the result is NOT divided by size; the caller scales it.
template <class T>
void fourier_transform(std::vector<T> &x, const std::vector<T> &zetas, bool inverse){
	const std::size_t n = x.size();
	const std::size_t wrap = n - 1;
	std::vector<T> buffer(n);
	for(std::size_t half = n / 2; half != 0; half /= 2){
		for(std::size_t pos = 0; pos < n; ++pos){
			const std::size_t low = pos & (half - 1);
			const std::size_t high = pos ^ low;
			// Source pair: the element at (2 * high mod n) + low and its partner `half` further on.
			const std::size_t src = ((high << 1) & wrap) | low;
			const std::size_t twiddle = (inverse && high != 0) ? n - high : high;
			buffer[pos] = x[src] + zetas[twiddle] * x[src | half];
		}
		std::swap(x, buffer);
	}
}

/// Multiplies every element of [begin, end) by scalar, in place.
template <class T, class Iterator>
void poly_scale(Iterator begin, Iterator end, const T &scalar){
	while(begin != end){
		*begin *= scalar;
		++begin;
	}
}

// Residues modulo the NTT-friendly prime 998244353 (= 119 * 2^23 + 1); all polynomial code below uses this type.
using Mint = ModInt<998244353>;

/// Returns {1, w, w^2, ..., w^(L-1)} with L = 2^log2length, where w is a primitive
/// L-th root of unity modulo 998244353 (the twiddle table for fourier_transform).
///
/// 998244353 = 119 * 2^23 + 1 and 3 is a primitive root of it, so 3^119 has order 2^23.
/// Squaring it (23 - log2length) times gives an element of order 2^log2length.
/// @throws std::invalid_argument if log2length > 23 (no such root exists; previously
///         the unsigned subtraction wrapped and the loop ran ~4e9 times).
std::vector<Mint> zetas998244353(const unsigned log2length) {
	constexpr unsigned max_log2 = 23;
	if(log2length > max_log2) throw std::invalid_argument("zetas998244353: transform length exceeds 2^23");
	const std::size_t length = static_cast<std::size_t>(1) << log2length;
	auto zeta = Mint(3).pow(119);
	for(unsigned i = log2length; i < max_log2; ++i) zeta *= zeta;
	std::vector<Mint> ret(length, 1);
	for(std::size_t i = 1; i < length; ++i) ret[i] = ret[i - 1] * zeta;
	return ret;
}

void poly_mul(std::vector<Mint> &x, std::vector<Mint> y){
	unsigned bit = bit_width(x.size() + y.size() - 2);
	auto n = (std::size_t)1 << bit;
	x.resize(n, 0);
	y.resize(n, 0);
	auto zetas = zetas998244353(bit);
	fourier_transform(x, zetas, false);
	fourier_transform(y, zetas, false);
	for(std::size_t i = 0; i < n; ++i) x[i] *= y[i];
	fourier_transform(x, zetas, true);
	poly_scale(x.begin(), x.end(), Mint(n).inv());
}

std::vector<Mint> poly_inv(std::vector<Mint> f, unsigned target_bit){
	auto target_length = (std::size_t)1 << target_bit;
	std::vector<Mint> g(target_length);
	f.resize(target_length);
	g[0] = f[0].inv();
	for(unsigned i = 0; i < target_bit; ++i){
		std::size_t prev_length = (std::size_t)1 << i;
		std::size_t next_length = (std::size_t)1 << i + 1;
		std::vector<Mint> partial_g(next_length);
		std::copy(g.begin(), g.begin() + prev_length, partial_g.begin());
		auto zetas = zetas998244353(i + 1);
		fourier_transform(partial_g, zetas, false);
		std::vector<Mint> tmp(f.begin(), f.begin() + next_length);
		fourier_transform(tmp, zetas, false);
		for(std::size_t i = 0; i < next_length; ++i) tmp[i] *= partial_g[i];
		fourier_transform(tmp, zetas, true);
		for(std::size_t i = 0; i < prev_length; ++i) tmp[i] = 0;
		poly_scale(tmp.begin() + prev_length, tmp.begin() + next_length, Mint(next_length).inv());
		fourier_transform(tmp, zetas, false);
		for(std::size_t i = 0; i < next_length; ++i) tmp[i] *= partial_g[i];
		fourier_transform(tmp, zetas, true);
		poly_scale(tmp.begin(), tmp.end(), Mint(next_length).inv());
		for(std::size_t i = prev_length; i < next_length; ++i) g[i] = -tmp[i];
	}
	return g;
}

std::vector<Mint> fibonacci_polynomial(unsigned n) {
	std::vector<Mint> ret(n / 2 + 1);
	Mint tmp = ret[0] = 1;
	for(unsigned i = 0; i < n / 2; ++i){
		tmp *= n - i * 2;
		tmp *= n - i * 2 - 1;
		tmp /= i + 1;
		tmp /= n - i;
		ret[i + 1] = i % 2 ? tmp : -tmp;
	}
	return ret;
}

void poly_diff(std::vector<Mint> &x){
	for(std::size_t i = 0; i < x.size(); ++i) x[i] *= i;
	if(x.size() > 1) x.erase(x.begin());
}

int main(int argc, char *argv[]){
	registerValidation(argc, argv);
	std::size_t n = inf.readInt(1, 200000, "n");
	inf.readSpace();
	std::size_t k = inf.readInt(1, n, "k");
	inf.readEoln();
	inf.readEof();

	Mint entire = 1;
	for(std::size_t i = 0; i < n; ++i){
		entire *= n * 2 - i;
		entire /= i + 1;
	}

	unsigned bit = bit_width(n);

	auto p = fibonacci_polynomial(k);
	poly_diff(p);
	auto q = poly_inv(fibonacci_polynomial(k - 1), bit);
	if(q.size() > n + 1) q.resize(n + 1);
	poly_mul(p, std::move(q));

	auto r = fibonacci_polynomial(k + 1);
	poly_diff(r);
	auto s = poly_inv(fibonacci_polynomial(k), bit);
	if(s.size() > n + 1) s.resize(n + 1);
	poly_mul(r, std::move(s));

	std::cout << entire - p[n] + r[n] << std::endl;
}

/// Multiplicative inverse modulo MOD via the extended Euclidean algorithm.
///
/// Runs entirely in unsigned arithmetic. Instead of signed Bezout coefficients it keeps
/// their magnitudes (s, t) and tracks the sign of the final coefficient in `sign`, which
/// flips on every division step. Assumes value is coprime to MOD (for prime MOD:
/// value != 0). For value == 0 the loop never runs, and the result is MOD, which the
/// ModInt constructor reduces to 0.
template <std::uint32_t MOD>
constexpr ModInt<MOD> ModInt<MOD>::inv() const {
	unsigned a = MOD, s = 0;   // a = MOD corresponds to 0 * value
	unsigned b = value, t = 1; // b = value corresponds to 1 * value
	bool sign = true;
	while(b){
		s += t * (a / b);      // magnitude of the coefficient for the new remainder
		a %= b;
		sign = !sign;          // coefficients alternate in sign
		std::swap(a, b);
		std::swap(s, t);
	}
	if(sign) s = MOD - s;      // apply the tracked sign to get a residue in [0, MOD]
	return s;
}

/// Returns *this raised to the n-th power by square-and-multiply (pow(0) == 1).
template <std::uint32_t MOD>
constexpr ModInt<MOD> ModInt<MOD>::pow(unsigned n) const {
	ModInt result = 1;
	for(ModInt base = *this; n != 0; n >>= 1){
		if(n & 1u) result *= base;
		base *= base;
	}
	return result;
}

0