結果

問題 No.1939 Numbered Colorful Balls
ユーザー 沙耶花沙耶花
提出日時 2023-01-25 22:08:13
言語 C++17 (gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 297 ms / 2,000 ms
コード長 6,129 bytes
コンパイル時間 5,316 ms
コンパイル使用メモリ 286,128 KB
実行使用メモリ 35,580 KB
最終ジャッジ日時 2024-06-27 00:01:00
合計ジャッジ時間 11,757 ms
ジャッジサーバーID (参考情報) judge2 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 291 ms 35,580 KB
testcase_01 AC 60 ms 27,904 KB
testcase_02 AC 293 ms 35,468 KB
testcase_03 AC 293 ms 35,456 KB
testcase_04 AC 60 ms 27,904 KB
testcase_05 AC 293 ms 35,576 KB
testcase_06 AC 293 ms 35,544 KB
testcase_07 AC 113 ms 29,748 KB
testcase_08 AC 85 ms 29,020 KB
testcase_09 AC 295 ms 34,792 KB
testcase_10 AC 63 ms 28,032 KB
testcase_11 AC 173 ms 31,584 KB
testcase_12 AC 72 ms 28,416 KB
testcase_13 AC 63 ms 28,032 KB
testcase_14 AC 113 ms 29,560 KB
testcase_15 AC 86 ms 28,864 KB
testcase_16 AC 174 ms 31,908 KB
testcase_17 AC 172 ms 31,584 KB
testcase_18 AC 115 ms 29,760 KB
testcase_19 AC 290 ms 34,904 KB
testcase_20 AC 114 ms 29,580 KB
testcase_21 AC 295 ms 34,624 KB
testcase_22 AC 73 ms 28,416 KB
testcase_23 AC 172 ms 31,744 KB
testcase_24 AC 295 ms 34,700 KB
testcase_25 AC 173 ms 31,488 KB
testcase_26 AC 297 ms 35,536 KB
testcase_27 AC 61 ms 27,904 KB
testcase_28 AC 61 ms 27,904 KB
testcase_29 AC 173 ms 31,676 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <stdio.h>
#include <atcoder/all>
#include <bits/stdc++.h>
using namespace std;
using namespace atcoder;
using mint = modint998244353;
#define rep(i,n) for (int i = 0; i < (n); ++i)
#define Inf32 1000000001
#define Inf64 1000000000000000001

// Factorial-table helper over mint.
//   kaijou[i]  = i!        for 0 <= i <= n   ("kaijou" = factorial)
//   kaijou_[i] = (i!)^{-1} for 0 <= i <= n
// All queries require their arguments to stay within the table size n
// passed to the constructor.
struct combi{
	deque<mint> kaijou;
	deque<mint> kaijou_;
	
	// Builds both tables for arguments 0..n. The inverse table is filled from
	// the top down using (k-1)!^{-1} = (k!)^{-1} * k, so only a single modular
	// inverse (of n!) is computed.
	combi(int n){
		kaijou.push_back(1);
		for(int i=1;i<=n;i++){
			kaijou.push_back(kaijou[i-1]*i);
		}
		
		mint b=kaijou[n].inv();
		
		kaijou_.push_front(b);
		for(int i=1;i<=n;i++){
			int k=n+1-i;
			kaijou_.push_front(kaijou_[0]*k);
		}
	}
	
	// Binomial coefficient C(n, r) = n! / (r! (n-r)!).
	// Returns 0 when r lies outside [0, n]; the r < 0 guard prevents the
	// negative table index that the original r > n check alone allowed.
	mint combination(int n,int r){
		if(r<0 || r>n)return 0;
		mint a = kaijou[n]*kaijou_[r];
		a *= kaijou_[n-r];
		return a;
	}
	
	// (a+b)! / (a! b!): number of arrangements of a items of one kind and
	// b items of another. Returns 0 when a or b is negative (instead of
	// indexing the tables out of range).
	mint junretsu(int a,int b){
		if(a<0 || b<0)return 0;
		mint x = kaijou_[a]*kaijou_[b];
		x *= kaijou[a+b];
		return x;
	}
	
	// n-th Catalan number C(2n, n) / (n+1).
	mint catalan(int n){
		return combination(2*n,n)/(n+1);
	}
	
};



// Global factorial / inverse-factorial tables for arguments 0..3000000.
// NOTE(review): C is not referenced by main() in the visible code, yet its
// construction fills two 3,000,001-entry tables on every run.
combi C{3000000};

// Problem input read by main(): n, m, and the m values L[0..m-1].
int n = 0, m = 0;
vector<int> L{};


// https://nyaannyaan.github.io/library/fps/formal-power-series.hpp
// Formal power series over mint stored as a coefficient vector:
// (*this)[i] is the coefficient of x^i.
struct FormalPowerSeries : vector<mint> {
	using vector<mint>::vector;
	using FPS = FormalPowerSeries;

	// Coefficient-wise addition; grows *this to r.size() if r is longer.
	FPS &operator+=(const FPS &r) {
		if (r.size() > this->size()) this->resize(r.size());
		for (int i = 0; i < (int)r.size(); i++) (*this)[i] += r[i];
		return *this;
	}

	// Adds the constant r to the x^0 coefficient (creating it if empty).
	FPS &operator+=(const mint &r) {
		if (this->empty()) this->resize(1);
		(*this)[0] += r;
		return *this;
	}

	// Coefficient-wise subtraction; grows *this to r.size() if r is longer.
	FPS &operator-=(const FPS &r) {
		if (r.size() > this->size()) this->resize(r.size());
		for (int i = 0; i < (int)r.size(); i++) (*this)[i] -= r[i];
		return *this;
	}

	// Subtracts the constant r from the x^0 coefficient (creating it if empty).
	FPS &operator-=(const mint &r) {
		if (this->empty()) this->resize(1);
		(*this)[0] -= r;
		return *this;
	}
	
	// Full (untruncated) polynomial product computed with convolution().
	FPS &operator*=(const FPS &r) {
		auto ret = convolution(r,*this);
		return (*this) = FPS(ret.begin(),ret.end());
	}

	// Multiplies every coefficient by the scalar v.
	FPS &operator*=(const mint &v) {
		for (int k = 0; k < (int)this->size(); k++) (*this)[k] *= v;
		return *this;
	}	

	// Polynomial quotient (floor division) by r; the result has
	// size() - r.size() + 1 coefficients, or is empty when *this is shorter
	// than r. Divisors with at most 64 stored coefficients use schoolbook long
	// division (trailing zeros of r stripped first); longer divisors use the
	// reversed-series inverse, which assumes r.back() != 0.
	FPS &operator/=(const FPS &r) {
		if (this->size() < r.size()) {
			this->clear();
			return *this;
		}
		int n = this->size() - r.size() + 1;
		if ((int)r.size() <= 64) {
			FPS f(*this), g(r);
			g.shrink();
			// Normalize g to a monic divisor; the quotient is rescaled below.
			mint coeff = g.back().inv();
			for (auto &x : g) x *= coeff;
			int deg = (int)f.size() - (int)g.size() + 1;
			int gs = g.size();
			FPS quo(deg);
			for (int i = deg - 1; i >= 0; i--) {
				quo[i] = f[i + gs - 1];
				for (int j = 0; j < gs; j++) f[i + j] -= quo[i] * g[j];
			}
			*this = quo * coeff;
			this->resize(n, mint(0));
			return *this;
		}
		return *this = ((*this).rev().pre(n) * r.rev().inv(n)).pre(n).rev();
	}

	// Polynomial remainder by r, with trailing zero coefficients removed.
	FPS &operator%=(const FPS &r) {
		*this -= *this / r * r;
		shrink();
		return *this;
	}

	FPS operator+(const FPS &r) const { return FPS(*this) += r; }
	FPS operator+(const mint &v) const { return FPS(*this) += v; }
	FPS operator-(const FPS &r) const { return FPS(*this) -= r; }
	FPS operator-(const mint &v) const { return FPS(*this) -= v; }
	FPS operator*(const FPS &r) const { return FPS(*this) *= r; }
	FPS operator*(const mint &v) const { return FPS(*this) *= v; }
	FPS operator/(const FPS &r) const { return FPS(*this) /= r; }
	FPS operator%(const FPS &r) const { return FPS(*this) %= r; }

	// Coefficient-wise negation.
	FPS operator-() const {
		FPS ret(this->size());
		for (int i = 0; i < (int)this->size(); i++) ret[i] = -(*this)[i];
		return ret;
	}

	// Removes trailing zero coefficients.
	void shrink() {
		while (this->size() && this->back() == mint(0)) this->pop_back();
	}

	// Coefficients in reverse order.
	FPS rev() const {
		FPS ret(*this);
		reverse(ret.begin(), ret.end());
		return ret;
	}

	// Pointwise (Hadamard) product, truncated to the shorter operand.
	FPS dot(const FPS &r) const {
		FPS ret(min(this->size(), r.size()));
		for (int i = 0; i < (int)ret.size(); i++) ret[i] = (*this)[i] * r[i];
		return ret;
	}

	// First sz coefficients (i.e. *this mod x^sz). sz is clamped to
	// [0, size()]; a negative sz previously produced begin() + sz (UB),
	// which log(0) and pow(k, 0) reached via pre(deg - 1).
	FPS pre(int sz) const {
		const int len = max(0, min((int)this->size(), sz));
		return FPS((*this).begin(), (*this).begin() + len);
	}

	// Drops the lowest sz coefficients (division by x^sz, discarding remainder).
	FPS operator>>(int sz) const {
		if ((int)this->size() <= sz) return {};
		FPS ret(*this);
		ret.erase(ret.begin(), ret.begin() + sz);
		return ret;
	}

	// Multiplies by x^sz (prepends sz zero coefficients).
	FPS operator<<(int sz) const {
		FPS ret(*this);
		ret.insert(ret.begin(), sz, mint(0));
		return ret;
	}

	// Formal derivative.
	FPS diff() const {
		const int n = (int)this->size();
		FPS ret(max(0, n - 1));
		mint one(1), coeff(1);
		for (int i = 1; i < n; i++) {
			ret[i - 1] = (*this)[i] * coeff;
			coeff += one;
		}
		return ret;
	}

	// Formal antiderivative with zero constant term. The inverses 1/i are
	// built in place with the recurrence inv[i] = -inv[mod % i] * (mod / i).
	FPS integral() const {
		const int n = (int)this->size();
		FPS ret(n + 1);
		ret[0] = mint(0);
		if (n > 0) ret[1] = mint(1);
		auto mod = mint::mod();
		for (int i = 2; i <= n; i++) ret[i] = (-ret[mod % i]) * (mod / i);
		for (int i = 0; i < n; i++) ret[i + 1] *= (*this)[i];
		return ret;
	}

	// Evaluates the series at x.
	mint eval(mint x) const {
		mint r = 0, w = 1;
		for (auto &v : *this) r += w * v, w *= x;
		return r;
	}

	// log(f) up to x^deg (default deg = size()); requires f[0] == 1.
	FPS log(int deg = -1) const {
		assert((*this)[0] == mint(1));
		if (deg == -1) deg = (int)this->size();
		return (this->diff() * this->inv(deg)).pre(deg - 1).integral();
	}

	// First deg coefficients of f^k (default deg = size()). The lowest nonzero
	// term c*x^i is factored out so log/exp can be applied to a series with
	// constant term 1.
	FPS pow(int64_t k, int deg = -1) const {
		const int n = (int)this->size();
		if (deg == -1) deg = n;
		if (k == 0) {
			FPS ret(deg);
			if (deg) ret[0] = 1;
			return ret;
		}
		for (int i = 0; i < n; i++) {
			if ((*this)[i] != mint(0)) {
				mint rev = mint(1) / (*this)[i];
				FPS ret = (((*this * rev) >> i).log(deg) * k).exp(deg);
				ret *= (*this)[i].pow(k);
				ret = (ret << (i * k)).pre(deg);
				if ((int)ret.size() < deg) ret.resize(deg, mint(0));
				return ret;
			}
			// Lowest nonzero degree is at least i+1, so f^k starts at
			// x^{(i+1)k}; once that reaches deg every kept coefficient is 0.
			if (__int128_t(i + 1) * k >= deg) return FPS(deg, mint(0));
		}
		return FPS(deg, mint(0));
	}

	// First deg coefficients of 1/f (default deg = size()); requires f[0] != 0.
	// Newton iteration doubling the precision each round.
	FPS inv(int deg = -1) const {
		assert((*this)[0] != mint(0));
		if (deg == -1) deg = (*this).size();
		FPS ret({mint(1) / (*this)[0]});
		for (int i = 1; i < deg; i <<= 1)
			ret = (ret + ret - ret * ret * (*this).pre(i << 1)).pre(i << 1);
		return ret.pre(deg);
	}
	
	// First deg coefficients of exp(f) (default deg = size()); requires
	// f to be empty or f[0] == 0. Newton iteration g <- g (f + 1 - log g).
	FPS exp(int deg = -1) const{
		assert((*this).size() == 0 || (*this)[0] == mint(0));
		if (deg == -1) deg = (int)this->size();
		FPS ret({mint(1)});
		for (int i = 1; i < deg; i <<= 1) {
			ret = (ret * (pre(i << 1) + mint(1) - ret.log(i << 1))).pre(i << 1);
		}
		return ret.pre(deg);
	}
};



using fps = FormalPowerSeries;

int main(){
	
	cin>>n>>m;
	L.resize(m);
	rep(i,m)cin>>L[i];
	
	fps f(n+5);
	rep(i,m){
		f[L[i]] = 1;
	}
	f[0] = 1;
	
	f = f.pow(n+1,n+1);
	mint ans = f[n];
	ans /= n+1;
	
	cout<<ans.val()<<endl;
	
	
	return 0;
}
0