結果

問題 No.3273 Exactly One Match
ユーザー 遭難者
提出日時 2025-08-11 19:55:10
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 785 ms / 4,000 ms
コード長 12,584 bytes
コンパイル時間 5,928 ms
コンパイル使用メモリ 334,820 KB
実行使用メモリ 72,656 KB
最終ジャッジ日時 2025-09-12 23:48:24
合計ジャッジ時間 14,854 ms
ジャッジサーバーID
(参考情報)
judge6 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 26
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
using std::cerr;
using std::cin;
using std::cout;
#if __has_include(<atcoder/all>)
#include <atcoder/all>
using mint = atcoder::modint998244353;
// Reads one (possibly negative) integer token and stores it as a mint; the
// modint constructor reduces it into [0, mod).
istream &operator>>(istream &is, mint &a) {
	long long raw;
	is >> raw;
	a = mint(raw);
	return is;
}
// Writes the representative returned by val(), i.e. a value in [0, mod).
ostream &operator<<(ostream &os, mint a) {
	os << a.val();
	return os;
}
#endif
// Shorthand for long double.
typedef long double ld;
// From this point on every `long` token is rewritten to `long long`, so plain
// `long` is the 64-bit type. Do not write `long long` below this line: it
// would expand to `long long long long` and fail to compile.
#define long long long
#define uint unsigned int
// Expands to `unsigned long`, which the `long` macro above turns into
// `unsigned long long`.
#define ull unsigned long
// rep(n), rep(i, n), rep(i, a, b): ascending loops over [0, n) or [a, b) with
// an `int` index. overload3 selects the variant from the argument count.
// NOTE(review): the one-argument forms use `__i`, an identifier reserved for
// the implementation.
#define overload3(a, b, c, name, ...) name
#define rep3(i, a, b) for (int i = (a); i < (b); i++)
#define rep2(i, n) rep3(i, 0, n)
#define rep1(n) rep2(__i, n)
#define rep(...) overload3(__VA_ARGS__, rep3, rep2, rep1)(__VA_ARGS__)
// per(...): same argument forms as rep, iterating downwards from b - 1 to a.
#define per3(i, a, b) for (int i = (b) - 1; i >= (a); i--)
#define per2(i, n) per3(i, 0, n)
#define per1(n) per2(__i, n)
#define per(...) overload3(__VA_ARGS__, per3, per2, per1)(__VA_ARGS__)
// Begin/end iterator pair of a container. The argument is not parenthesized,
// so pass a plain name rather than a compound expression.
#define all(a) a.begin(), a.end()
// Sorts `a` and erases duplicates in place. It expands to two statements, so
// it must not be the sole body of an unbraced if/for/while.
#define UNIQUE(a)                                                              \
	sort(all(a));                                                              \
	a.erase(unique(all(a)), a.end())
#define sz(a) (int)a.size()
#define vec vector
// Non-DEBUG builds: cerr output is compiled out, assert() becomes a no-op,
// and endl becomes a plain '\n' (no flush per line).
// NOTE(review): `cerr << x;` expands to `if (0) cerr << x;`, so an `else`
// written right after such a statement would bind to this hidden `if`.
#ifndef DEBUG
#define cerr                                                                   \
	if (0)                                                                     \
	cerr
#undef assert
#define assert(...) void(0)
#undef endl
#define endl '\n'
#endif
// Prints the elements of `a` separated by single spaces, with no trailing
// space and no newline. The vector is taken by const reference so printing
// does not copy it (the previous by-value parameter copied the whole vector).
template <typename T> ostream &operator<<(ostream &os, const vector<T> &a) {
	const int n = a.size();
	for (int i = 0; i < n; i++) {
		if (i != 0)
			os << ' ';
		os << a[i];
	}
	return os;
}
// Prints the elements of `a` separated by single spaces, with no trailing
// space and no newline. The array is taken by const reference to avoid a
// copy, and the index is size_t so it compares cleanly with the size_t bound n.
template <typename T, size_t n>
ostream &operator<<(ostream &os, const array<T, n> &a) {
	for (size_t i = 0; i < n; i++) {
		if (i != 0)
			os << ' ';
		os << a[i];
	}
	return os;
}
// Fills every existing element of `a`, front to back, from `is`. The caller
// sizes the vector beforehand; its length is not changed here.
template <typename T> istream &operator>>(istream &is, vector<T> &a) {
	for (auto it = a.begin(); it != a.end(); ++it)
		is >> *it;
	return is;
}
// Lowers x to y (converted to T) when that value is strictly smaller.
// Returns true iff x was changed.
template <typename T, typename S> bool chmin(T &x, S y) {
	if (!(T(y) < x))
		return false;
	x = T(y);
	return true;
}
// Raises x to y (converted to T) when that value is strictly larger.
// Returns true iff x was changed.
template <typename T, typename S> bool chmax(T &x, S y) {
	if (!(x < T(y)))
		return false;
	x = T(y);
	return true;
}
// Element-wise increment/decrement of a whole vector, in place (for example,
// to switch indices between 0-based and 1-based). The prefix forms apply
// prefix ++/-- to each element and the postfix forms apply postfix ++/--.
// All four return void.
template <typename T> void operator++(vector<T> &a) {
	for (auto it = a.begin(); it != a.end(); ++it)
		++*it;
}
template <typename T> void operator--(vector<T> &a) {
	for (auto it = a.begin(); it != a.end(); ++it)
		--*it;
}
template <typename T> void operator++(vector<T> &a, int) {
	for (size_t idx = 0; idx < a.size(); ++idx)
		a[idx]++;
}
template <typename T> void operator--(vector<T> &a, int) {
	for (size_t idx = 0; idx < a.size(); ++idx)
		a[idx]--;
}
// Makes AtCoder Library names (convolution, internal::butterfly, ...) usable
// unqualified below. NOTE(review): unlike the #include of <atcoder/all>, this
// line is not inside the __has_include guard, so the file requires ACL.
using namespace atcoder;
// 64-bit signed integer: `long` is rewritten to `long long` by the macro above.
using ll = long;
template <class T> vector<T> NTT(vector<T> a, vector<T> b) {
	ll nmod = T::mod();
	int n = a.size();
	int m = b.size();
	vector<int> x1(n);
	vector<int> y1(m);
	for (int i = 0; i < n; i++) {
		ll tmp1, tmp2, tmp3;
		tmp1 = a[i].val();
		x1[i] = tmp1;
	}
	for (int i = 0; i < m; i++) {
		ll tmp1, tmp2, tmp3;
		tmp1 = b[i].val();
		y1[i] = tmp1;
	}
	auto z1 = convolution<167772161>(x1, y1);
	auto z2 = convolution<469762049>(x1, y1);
	auto z3 = convolution<1224736769>(x1, y1);
	vector<T> res(n + m - 1);
	constexpr ll m1 = 167772161;
	constexpr ll m2 = 469762049;
	constexpr ll m3 = 1224736769;
	constexpr ll m1m2 = 104391568;
	constexpr ll m1m2m3 = 721017874;
	ll mm12 = m1 * m2 % nmod;
	for (int i = 0; i < n + m - 1; i++) {
		int v1 = (z2[i] - z1[i]) * m1m2 % m2;
		if (v1 < 0)
			v1 += m2;
		int v2 = (z3[i] - (z1[i] + v1 * m1) % m3) * m1m2m3 % m3;
		if (v2 < 0)
			v2 += m3;
		res[i] = (z1[i] + v1 * m1 + v2 * mm12);
	}
	return res;
}
// Truncated formal power series with coefficients in T.
// T must be an AtCoder static_modint-style type: members below use T::mod(),
// val(), inv() and pow(), multiply through atcoder::convolution, and run Newton
// iterations with atcoder::internal::butterfly (so T::mod() must be an
// NTT-friendly prime such as 998244353).
// Unless stated otherwise an operation keeps the length of *this.
template <class T> struct FormalPowerSeries : vector<T> {
	using vector<T>::vector;
	using F = FormalPowerSeries;

	// Replace the contents with g; afterwards size() == g.size(). These accept
	// the plain vector<T> returned by atcoder::convolution. Assignment from
	// another F uses the implicitly generated copy/move assignment, which has
	// the same semantics. A grow-only assignment would keep stale coefficients
	// past g's end, e.g. `f *= {}` would leave f unchanged.
	F &operator=(const vector<T> &g) {
		vector<T>::operator=(g);
		return *this;
	}
	F &operator=(vector<T> &&g) {
		vector<T>::operator=(std::move(g));
		return *this;
	}
	// Returns -f as a new series; *this is left unchanged.
	F operator-() const {
		F ret(*this);
		for (T &x : ret)
			x = -x;
		return ret;
	}
	// Coefficient-wise sum; the length becomes max(size(), g.size()).
	F &operator+=(const F &g) {
		if (this->size() < g.size())
			this->resize(g.size());
		for (size_t i = 0; i < g.size(); i++)
			(*this)[i] += g[i];
		return *this;
	}
	// Adds the constant r to the x^0 coefficient (creating it if empty).
	F &operator+=(const T &r) {
		if (this->empty())
			this->resize(1);
		(*this)[0] += r;
		return *this;
	}
	// Coefficient-wise difference; the length becomes max(size(), g.size()).
	F &operator-=(const F &g) {
		if (this->size() < g.size())
			this->resize(g.size());
		for (size_t i = 0; i < g.size(); i++)
			(*this)[i] -= g[i];
		return *this;
	}
	// Subtracts the constant r from the x^0 coefficient (creating it if empty).
	F &operator-=(const T &r) {
		if (this->empty())
			this->resize(1);
		(*this)[0] -= r;
		return *this;
	}
	// Full (untruncated) product: size() + g.size() - 1 terms, or empty if
	// either factor is empty.
	F &operator*=(const F &g) {
		*this = convolution(*this, g);
		return *this;
	}
	// Multiplies every coefficient by r. r is copied first so that
	// `f *= f[i]` scales every coefficient by the original f[i].
	F &operator*=(const T &r) {
		const T s = r;
		for (T &x : *this)
			x *= s;
		return *this;
	}
	// Series division: the first size() coefficients of *this / g.
	// Requires g[0] to be invertible.
	F &operator/=(const F &g) {
		const int n = this->size();
		// Exactly n terms of 1/g are needed. Inverting to only g.size() terms
		// gives wrong high coefficients when g is shorter than *this.
		*this = convolution(*this, g.inv(n));
		this->resize(n);
		return *this;
	}
	// Divides every coefficient by the scalar r (r must be invertible).
	F &operator/=(const T &r) {
		const T rinv = r.inv();
		for (T &x : *this)
			x *= rinv;
		return *this;
	}
	// Multiplies by x^d, keeping the length (the top d terms fall off).
	F &operator<<=(const int d) {
		const int n = this->size();
		this->insert(this->begin(), d, T(0));
		this->resize(n);
		return *this;
	}
	// Drops the lowest d terms and zero-pads back to the same length.
	F &operator>>=(const int d) {
		const int n = this->size();
		this->erase(this->begin(), this->begin() + min(n, d));
		this->resize(n);
		return *this;
	}
	F operator*(const T &g) const { return F(*this) *= g; }
	F operator-(const T &g) const { return F(*this) -= g; }
	F operator+(const T &g) const { return F(*this) += g; }
	F operator/(const T &g) const { return F(*this) /= g; }
	F operator*(const F &g) const { return F(*this) *= g; }
	F operator-(const F &g) const { return F(*this) -= g; }
	F operator+(const F &g) const { return F(*this) += g; }
	F operator/(const F &g) const { return F(*this) /= g; }
	// NOTE(review): no operator%= is defined, so instantiating this member
	// fails to compile. It is kept only so that the interface is unchanged.
	F operator%(const F &g) const { return F(*this) %= g; }
	F operator<<(const int d) const { return F(*this) <<= d; }
	F operator>>(const int d) const { return F(*this) >>= d; }
	// First min(size(), sz) coefficients.
	F pre(int sz) const {
		return F(begin(*this), begin(*this) + min((int)this->size(), sz));
	}
	// Multiplicative inverse, first deg terms (default size()).
	// Requires size() > 0 and an invertible constant term. Each loop pass
	// doubles the number of correct terms using length-2m NTTs.
	F inv(int deg = -1) const {
		int n = (*this).size();
		if (deg == -1)
			deg = n;
		assert(n > 0 && (*this)[0] != T(0));
		F g(1);
		g[0] = (*this)[0].inv();
		while (g.size() < deg) {
			int m = g.size();
			F f(begin(*this), begin(*this) + min(n, 2 * m));
			F r(g);
			f.resize(2 * m);
			r.resize(2 * m);
			internal::butterfly(f);
			internal::butterfly(r);
			for (int i = 0; i < 2 * m; i++)
				f[i] *= r[i];
			internal::butterfly_inv(f);
			// Keep the upper half of f*g (the lower half is known: 1, 0, ...).
			f.erase(f.begin(), f.begin() + m);
			f.resize(2 * m);
			internal::butterfly(f);
			for (int i = 0; i < 2 * m; i++)
				f[i] *= r[i];
			internal::butterfly_inv(f);
			// Undo the two unnormalized inverse transforms and negate.
			T in = T(2 * m).inv();
			in *= -in;
			for (int i = 0; i < m; i++)
				f[i] *= in;
			g.insert(g.end(), f.begin(), f.begin() + m);
		}
		return g.pre(deg);
	}
	// Evaluates the stored (truncated) polynomial at a.
	T eval(const T &a) const {
		T x = 1;
		T ret = 0;
		for (const T &coef : *this) {
			ret += coef * x;
			x *= a;
		}
		return ret;
	}
	// In place: *this *= (1 + c x^d), truncated to the current length.
	void onemul(const int d, const T c) {
		int n = (*this).size();
		for (int i = n - d - 1; i >= 0; i--) {
			(*this)[i + d] += (*this)[i] * c;
		}
	}
	// In place: *this /= (1 + c x^d), truncated to the current length.
	void onediv(const int d, const T c) {
		int n = (*this).size();
		for (int i = 0; i < n - d; i++) {
			(*this)[i + d] -= (*this)[i] * c;
		}
	}
	// Formal derivative, same length (the top coefficient becomes 0).
	// F ret(n) is zero-initialized, so an empty series yields an empty result.
	F diff() const {
		const int n = this->size();
		F ret(n);
		for (int i = 1; i < n; i++)
			ret[i - 1] = (*this)[i] * i;
		return ret;
	}
	// Formal integral with zero constant term, same length (the top input
	// coefficient is dropped). Series of length 0 or 1 are handled without
	// touching the reciprocal table, which needs at least two entries.
	F integral() const {
		const int n = this->size(), mod = T::mod();
		F ret(n);
		if (n < 2)
			return ret;
		// invs[i] = 1 / i via the standard linear recurrence.
		vector<T> invs(n);
		invs[1] = 1;
		for (int i = 2; i < n; i++)
			invs[i] = T(mod) - invs[mod % i] * (mod / i);
		for (int i = n - 2; i >= 0; i--)
			ret[i + 1] = (*this)[i] * invs[i + 1];
		return ret;
	}
	// Logarithm, first deg terms (default size()). Requires f[0] == 1.
	// log f = integral(f' / f).
	F log(int deg = -1) const {
		int n = (*this).size();
		if (deg == -1)
			deg = n;
		assert((*this)[0] == T(1));
		return ((*this).diff() * (*this).inv(deg)).pre(deg).integral();
	}
	// Exponential, first deg terms (default size()). Requires f[0] == 0.
	// Doubling iteration: each loop pass appends m coefficients to the result
	// b, while c carries the inverse of b needed by the next pass.
	F exp(int deg = -1) const {
		int n = (*this).size();
		if (deg == -1)
			deg = n;
		assert(n == 0 || (*this)[0] == 0);
		// Inv caches the reciprocals 1/i used by inplace_integral.
		F Inv;
		Inv.reserve(deg);
		Inv.push_back(T(0));
		Inv.push_back(T(1));
		auto inplace_integral = [&](F &f) -> void {
			const int n = (int)f.size();
			int mod = T::mod();
			while (Inv.size() <= n) {
				int i = Inv.size();
				Inv.push_back((-Inv[mod % i]) * (mod / i));
			}
			f.insert(begin(f), T(0));
			for (int i = 1; i <= n; i++)
				f[i] *= Inv[i];
		};
		auto inplace_diff = [](F &f) -> void {
			if (f.empty())
				return;
			f.erase(begin(f));
			T coeff = 1;
			for (int i = 0; i < f.size(); i++) {
				f[i] *= coeff;
				coeff++;
			}
		};
		F b{1, 1 < (int)(*this).size() ? (*this)[1] : 0}, c{1}, z1, z2{1, 1};
		for (int m = 2; m <= deg; m <<= 1) {
			auto y = b;
			y.resize(2 * m);
			internal::butterfly(y);
			z1 = z2;
			F z(m);
			for (int i = 0; i < m; i++)
				z[i] = y[i] * z1[i];
			internal::butterfly_inv(z);
			T si = T(m).inv();
			for (int i = 0; i < m; i++)
				z[i] *= si;
			fill(begin(z), begin(z) + m / 2, T(0));
			internal::butterfly(z);
			for (int i = 0; i < m; i++)
				z[i] *= -z1[i];
			internal::butterfly_inv(z);
			for (int i = 0; i < m; i++)
				z[i] *= si;
			c.insert(end(c), begin(z) + m / 2, end(z));
			z2 = c;
			z2.resize(2 * m);
			internal::butterfly(z2);
			F x(begin((*this)), begin((*this)) + min<int>((*this).size(), m));
			x.resize(m);
			inplace_diff(x);
			x.push_back(T(0));
			internal::butterfly(x);
			for (int i = 0; i < m; i++)
				x[i] *= y[i];
			internal::butterfly_inv(x);
			for (int i = 0; i < m; i++)
				x[i] *= si;
			x -= b.diff();
			x.resize(2 * m);
			for (int i = 0; i < m - 1; i++)
				x[m + i] = x[i], x[i] = T(0);
			internal::butterfly(x);
			for (int i = 0; i < 2 * m; i++)
				x[i] *= z2[i];
			internal::butterfly_inv(x);
			T si2 = T(m << 1).inv();
			for (int i = 0; i < 2 * m; i++)
				x[i] *= si2;
			x.pop_back();
			inplace_integral(x);
			for (int i = m; i < min<int>((*this).size(), 2 * m); i++)
				x[i] += (*this)[i];
			fill(begin(x), begin(x) + m, T(0));
			internal::butterfly(x);
			for (int i = 0; i < 2 * m; i++)
				x[i] *= y[i];
			internal::butterfly_inv(x);
			for (int i = 0; i < 2 * m; i++)
				x[i] *= si2;
			b.insert(end(b), begin(x) + m, end(x));
		}
		return b.pre(deg);
	}
	// f^m truncated to size() terms, for m >= 0. The lowest nonzero term
	// y * x^x is factored out and the rest is raised as exp(m * log(.)).
	F pow(ll m) const {
		const int n = this->size();
		if (m == 0) {
			F ret(n);
			if (n > 0)
				ret[0] = 1;
			return ret;
		}
		int x = 0;
		while (x < n && (*this)[x] == T(0))
			x++;
		// x * m >= n: every kept coefficient is zero.
		if (x >= (n + m - 1) / m) {
			F ret(n);
			return ret;
		}
		F f(n - x);
		T y = (*this)[x];
		for (int i = x; i < n; i++)
			f[i - x] = (*this)[i] / y;
		f = f.log();
		for (int i = 0; i < n - x; i++)
			f[i] *= m;
		f = f.exp();
		y = y.pow(m);
		for (int i = 0; i < n - x; i++)
			f[i] *= y;
		F ret(n);
		const ll xm = x * m;
		for (int i = xm; i < n; i++)
			ret[i] = f[i - xm];
		return ret;
	}
	// Taylor shift: replaces *this with the coefficients of f(x + c) (same
	// length) and returns a copy of the result. An empty series is returned as is.
	F shift(T c) {
		const int n = this->size();
		if (n == 0)
			return *this;
		const int mod = T::mod();
		// invs[i] = 1 / i for 1 <= i <= n.
		vector<T> invs(n + 1);
		invs[1] = 1;
		for (int i = 2; i <= n; i++)
			invs[i] = mod - invs[mod % i] * (mod / i);
		// a_i <- a_i * i!
		T x = 1;
		for (int i = 0; i < n; i++) {
			(*this)[i] *= x;
			x *= (i + 1);
		}
		// g[n - 1 - i] = c^i / i!
		F g(n);
		T y = 1;
		T now = 1;
		for (int i = 0; i < n; i++) {
			g[n - i - 1] = now * y;
			now *= c;
			y *= invs[i + 1];
		}
		auto tmp = convolution(g, (*this));
		// b_j = tmp[n - 1 + j] / j!
		T z = 1;
		for (int i = 0; i < n; i++) {
			(*this)[i] = tmp[n + i - 1] * z;
			z *= invs[i + 1];
		}
		return (*this);
	}
};
using fps = FormalPowerSeries<mint>;
// Size bound for factorial tables (1e6 + 2022 is exactly representable, so the
// conversion to int is exact).
// The global mint fact/finv arrays of this size that used to live here were
// removed. They were never read, because solve() declares local tables with
// the same names, so they only cost static storage and start-up initialization.
constexpr int INF = 1e6 + 2022;
void solve() {
	vec<mint> fact(INF + 1), finv(INF + 1);
	fact[0] = 1;
	for (int i = 1; i <= INF; i++)
		fact[i] = i * fact[i - 1];
	finv[INF] = fact[INF].inv();
	for (int i = INF; i > 0; i--)
		finv[i - 1] = i * finv[i];
	int n;
	mint k;
	cin >> n >> k;
	k--;
	vec<mint> kpow(n + 1);
	kpow[0] = 1;
	for (int i = 1; i <= n; i++)
		kpow[i] = k * kpow[i - 1];
	vec<mint> c(n + 1);
	for (int i = 0; i <= n; i++)
		c[i] = kpow[i] + (i % 2 ? -k : k);
	assert(c[1].val() == 0);
	fps f0(n + 1), f1(n + 1), g0(n + 1);
	for (int i = 1; i <= n; i++) {
		f0[i] = c[i] * fact[i - 1] * finv[i];
		f1[i] = c[i - 1];
	}
	for (int i = 0; i <= n; i++)
		g0[i] = mint(n).pow(998244353LL - 1 + n - 1 - i) * i * finv[n - i];
	f0 = f0.exp();
	f1 *= f0;
	f0[0] = 0;
	mint ans = 0;
	rep(s, n) ans += f0[s] * g0[s] * kpow[n - s - 1] * (n - s);
	rep(s, n + 1) ans += f1[s] * g0[s] * kpow[n - s];
	cout << ans * fact[n] << endl;
}
int main() {
	// Unsynchronized, untied iostreams for fast bulk I/O.
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	cout << fixed << setprecision(20);
	// One test case per input; read the count here for multi-case inputs:
	// cin >> test_cases;
	int test_cases = 1;
	for (int tc = 0; tc < test_cases; tc++)
		solve();
}
0