結果

問題 No.2670 Sum of Products of Interval Lengths
ユーザー 沙耶花沙耶花
提出日時 2024-03-08 22:48:07
言語 C++17 (gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 415 ms / 2,000 ms
コード長 5,950 bytes
コンパイル時間 4,762 ms
コンパイル使用メモリ 273,744 KB
実行使用メモリ 11,252 KB
最終ジャッジ日時 2024-09-29 20:17:49
合計ジャッジ時間 10,301 ms
ジャッジサーバーID (参考情報) judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 3 ms 6,816 KB
testcase_01 AC 414 ms 11,120 KB
testcase_02 AC 181 ms 7,416 KB
testcase_03 AC 46 ms 6,816 KB
testcase_04 AC 182 ms 7,408 KB
testcase_05 AC 372 ms 10,104 KB
testcase_06 AC 388 ms 10,612 KB
testcase_07 AC 182 ms 7,288 KB
testcase_08 AC 46 ms 6,816 KB
testcase_09 AC 180 ms 7,280 KB
testcase_10 AC 373 ms 10,232 KB
testcase_11 AC 389 ms 10,608 KB
testcase_12 AC 414 ms 11,252 KB
testcase_13 AC 415 ms 11,252 KB
testcase_14 AC 412 ms 11,124 KB
testcase_15 AC 413 ms 11,128 KB
testcase_16 AC 413 ms 11,236 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <stdio.h>
#include <bits/stdc++.h>
#include <atcoder/all>
// AtCoder Library names (modint998244353, convolution) used unqualified below.
using namespace atcoder;
// Coefficient / answer type: integers modulo 998244353.
using mint = modint998244353;
using namespace std;
// rep(i, n): loop with int i = 0, 1, ..., n-1.
#define rep(i,n) for (int i = 0; i < (n); ++i)
// Large sentinel constants; not referenced by the code shown here.
#define Inf32 1000000001
#define Inf64 1000000000000000001
// Truncated formal power series over mint (arithmetic modulo 998244353).
// Element i holds the coefficient of x^i.
// Adapted from https://nyaannyaan.github.io/library/fps/formal-power-series.hpp
// Products are delegated to atcoder::convolution. inv/log/exp/pow work modulo
// x^deg, where deg defaults to this->size().
// NOTE(review): the visible dfs()/main() never use this type; it appears to be
// library boilerplate carried along with the submission.
struct FormalPowerSeries : vector<mint> {
	using vector<mint>::vector;
	using FPS = FormalPowerSeries;

	// Coefficient-wise sum; *this grows to r.size() when r is longer.
	FPS &operator+=(const FPS &r) {
		if (r.size() > this->size()) this->resize(r.size());
		for (int i = 0; i < (int)r.size(); i++) (*this)[i] += r[i];
		return *this;
	}

	// Adds r to the constant term, creating it if *this is empty.
	FPS &operator+=(const mint &r) {
		if (this->empty()) this->resize(1);
		(*this)[0] += r;
		return *this;
	}

	// Coefficient-wise difference; *this grows to r.size() when r is longer.
	FPS &operator-=(const FPS &r) {
		if (r.size() > this->size()) this->resize(r.size());
		for (int i = 0; i < (int)r.size(); i++) (*this)[i] -= r[i];
		return *this;
	}

	// Subtracts r from the constant term, creating it if *this is empty.
	FPS &operator-=(const mint &r) {
		if (this->empty()) this->resize(1);
		(*this)[0] -= r;
		return *this;
	}

	// Full (untruncated) product, computed with atcoder::convolution.
	FPS &operator*=(const FPS &r) {
		auto ret = convolution(r, *this);
		return (*this) = FPS(ret.begin(), ret.end());
	}

	// Multiplies every coefficient by v.
	FPS &operator*=(const mint &v) {
		for (int k = 0; k < (int)this->size(); k++) (*this)[k] *= v;
		return *this;
	}

	// Polynomial quotient of *this by r (the remainder is discarded). The result
	// has size() - r.size() + 1 coefficients, or is empty when r is longer.
	// Precondition: r is not the zero polynomial. For r.size() > 64 its last
	// stored coefficient must also be non-zero, because r.rev() is inverted.
	FPS &operator/=(const FPS &r) {
		if (this->size() < r.size()) {
			this->clear();
			return *this;
		}
		int n = this->size() - r.size() + 1;
		if ((int)r.size() <= 64) {
			// Small divisor: schoolbook long division by the monic form of r.
			FPS f(*this), g(r);
			g.shrink();
			// An empty or all-zero divisor would make g.back() undefined behaviour.
			assert(!g.empty());
			mint coeff = g.back().inv();
			for (auto &x : g) x *= coeff;
			int deg = (int)f.size() - (int)g.size() + 1;
			int gs = g.size();
			FPS quo(deg);
			for (int i = deg - 1; i >= 0; i--) {
				quo[i] = f[i + gs - 1];
				for (int j = 0; j < gs; j++) f[i + j] -= quo[i] * g[j];
			}
			// Undo the monic normalisation, then keep exactly n coefficients.
			*this = quo * coeff;
			this->resize(n, mint(0));
			return *this;
		}
		// Large divisor: reverse both operands, multiply by the series inverse of
		// rev(r) modulo x^n, and reverse the truncated product back.
		return *this = ((*this).rev().pre(n) * r.rev().inv(n)).pre(n).rev();
	}

	// Remainder of *this modulo r, with trailing zero coefficients removed.
	FPS &operator%=(const FPS &r) {
		*this -= *this / r * r;
		shrink();
		return *this;
	}

	FPS operator+(const FPS &r) const { return FPS(*this) += r; }
	FPS operator+(const mint &v) const { return FPS(*this) += v; }
	FPS operator-(const FPS &r) const { return FPS(*this) -= r; }
	FPS operator-(const mint &v) const { return FPS(*this) -= v; }
	FPS operator*(const FPS &r) const { return FPS(*this) *= r; }
	FPS operator*(const mint &v) const { return FPS(*this) *= v; }
	FPS operator/(const FPS &r) const { return FPS(*this) /= r; }
	FPS operator%(const FPS &r) const { return FPS(*this) %= r; }

	// Negates every coefficient.
	FPS operator-() const {
		FPS ret(this->size());
		for (int i = 0; i < (int)this->size(); i++) ret[i] = -(*this)[i];
		return ret;
	}

	// Removes trailing zero coefficients.
	void shrink() {
		while (this->size() && this->back() == mint(0)) this->pop_back();
	}

	// Coefficients in reverse order.
	FPS rev() const {
		FPS ret(*this);
		reverse(ret.begin(), ret.end());
		return ret;
	}

	// Coefficient-wise product, truncated to the shorter operand.
	FPS dot(const FPS &r) const {
		FPS ret(min(this->size(), r.size()));
		for (int i = 0; i < (int)ret.size(); i++) ret[i] = (*this)[i] * r[i];
		return ret;
	}

	// First min(size(), sz) coefficients, i.e. *this mod x^sz. A negative sz
	// yields an empty series instead of forming an iterator before begin().
	FPS pre(int sz) const {
		const int len = max(0, min((int)this->size(), sz));
		return FPS((*this).begin(), (*this).begin() + len);
	}

	// Drops the lowest sz coefficients (quotient by x^sz).
	FPS operator>>(int sz) const {
		if ((int)this->size() <= sz) return {};
		FPS ret(*this);
		ret.erase(ret.begin(), ret.begin() + sz);
		return ret;
	}

	// Multiplies by x^sz by prepending sz zero coefficients.
	FPS operator<<(int sz) const {
		FPS ret(*this);
		ret.insert(ret.begin(), sz, mint(0));
		return ret;
	}

	// Formal derivative; coeff holds the current exponent i as a mint.
	FPS diff() const {
		const int n = (int)this->size();
		FPS ret(max(0, n - 1));
		mint one(1), coeff(1);
		for (int i = 1; i < n; i++) {
			ret[i - 1] = (*this)[i] * coeff;
			coeff += one;
		}
		return ret;
	}

	// Formal antiderivative with zero constant term (size() + 1 coefficients).
	FPS integral() const {
		const int n = (int)this->size();
		FPS ret(n + 1);
		ret[0] = mint(0);
		if (n > 0) ret[1] = mint(1);
		auto mod = mint::mod();
		// First fill ret[i] = i^{-1} via inv(i) = -inv(mod % i) * (mod / i) ...
		for (int i = 2; i <= n; i++) ret[i] = (-ret[mod % i]) * (mod / i);
		// ... then scale each inverse by the coefficient being integrated.
		for (int i = 0; i < n; i++) ret[i + 1] *= (*this)[i];
		return ret;
	}

	// Evaluates sum_i c_i * x^i.
	mint eval(mint x) const {
		mint r = 0, w = 1;
		for (auto &v : *this) r += w * v, w *= x;
		return r;
	}

	// log(*this) mod x^deg, computed as the integral of f' / f.
	// Requires a constant term of 1.
	FPS log(int deg = -1) const {
		assert(!this->empty() && (*this)[0] == mint(1));
		if (deg == -1) deg = (int)this->size();
		// log mod x^0 is the empty series; the general formula would yield one term.
		if (deg == 0) return FPS();
		return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
	}

	// (*this)^k mod x^deg, returned with exactly deg coefficients (k is assumed
	// to be non-negative). Writes f = a * x^i * g with g[0] = 1 and computes
	// a^k * x^(i*k) * exp(k * log(g)).
	FPS pow(int64_t k, int deg = -1) const {
		const int n = (int)this->size();
		if (deg == -1) deg = n;
		if (k == 0) {
			FPS ret(deg);
			if (deg) ret[0] = 1;
			return ret;
		}
		for (int i = 0; i < n; i++) {
			if ((*this)[i] != mint(0)) {
				mint rev = mint(1) / (*this)[i];
				FPS ret = (((*this * rev) >> i).log(deg) * k).exp(deg);
				ret *= (*this)[i].pow(k);
				ret = (ret << (i * k)).pre(deg);
				if ((int)ret.size() < deg) ret.resize(deg, mint(0));
				return ret;
			}
			// The lowest non-zero term has index >= i + 1, so the power is divisible
			// by x^((i+1)k); once that reaches deg the truncated result is zero.
			if (__int128_t(i + 1) * k >= deg) return FPS(deg, mint(0));
		}
		return FPS(deg, mint(0));
	}

	// Multiplicative inverse mod x^deg by Newton iteration:
	// ret <- 2 ret - ret^2 * f, doubling the working length each step.
	// Requires a non-zero constant term.
	FPS inv(int deg = -1) const {
		assert(!this->empty() && (*this)[0] != mint(0));
		if (deg == -1) deg = (*this).size();
		FPS ret({mint(1) / (*this)[0]});
		for (int i = 1; i < deg; i <<= 1)
			ret = (ret + ret - ret * ret * (*this).pre(i << 1)).pre(i << 1);
		return ret.pre(deg);
	}

	// exp(*this) mod x^deg by Newton iteration:
	// ret <- ret * (f + 1 - log(ret)), doubling the working length each step.
	// Requires a zero constant term (or an empty series).
	FPS exp(int deg = -1) const {
		assert((*this).size() == 0 || (*this)[0] == mint(0));
		if (deg == -1) deg = (int)this->size();
		FPS ret({mint(1)});
		for (int i = 1; i < deg; i <<= 1) {
			ret = (ret * (pre(i << 1) + mint(1) - ret.log(i << 1))).pre(i << 1);
		}
		return ret.pre(deg);
	}
};



// Short alias for the formal-power-series type defined above.
using fps = FormalPowerSeries;
// Problem parameters read in main(); get() reads m. Zero until input is read.
long long n = 0;
long long m = 0;

// Kernel term for lag t: sign(t) * max(0, m - t + 1), where for t >= 0 the
// sign follows the period-6 pattern 0, +1, +1, 0, -1, -1 on t % 6.
mint get(long long t){
	const long long phase = t % 6;
	// t % 3 == 0 exactly when phase % 3 == 0.
	if (phase % 3 == 0) return mint(0);
	const long long len = (m - t + 1 > 0) ? m - t + 1 : 0;
	return (phase >= 4) ? mint(-len) : mint(len);
}
mint dp[200005];
void dfs(int l,int r){
	if(r-l==1){
		dp[l] += get(l+1);
	
		return;
	}
	else if(r-l<=1)return;
	int m = (l+r)/2;
	dfs(l,m);
	
	vector<mint> x(r-l);
	for(int i=l;i<m;i++){
		x[i-l] = dp[i];
	}
	vector<mint> y(r-l);
	rep(i,y.size())y[i] = get(i);
	x =convolution(x,y);
	rep(i,x.size()){
		int t = i + l;
		if(t >= m && t < r)dp[t] += x[i];
	}
	dfs(m,r);
}

// Reads n and m, runs the divide-and-conquer DP, and prints dp[n-1].
int main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr);

	cin >> n >> m;
	// dfs(0, n) writes dp[0..n-1] and the answer is dp[n-1]: n == 0 would read
	// dp[-1], and n above the array capacity would write past the end of dp.
	assert(1 <= n && n <= (long long)std::size(dp));

	dfs(0, (int)n);
	cout << dp[n - 1].val() << '\n';
	return 0;
}
0