結果
問題 | No.2487 Multiple of M |
ユーザー |
![]() |
提出日時 | 2023-09-29 23:59:40 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 | WA |
実行時間 | - |
コード長 | 3,193 bytes |
コンパイル時間 | 1,980 ms |
コンパイル使用メモリ | 207,372 KB |
最終ジャッジ日時 | 2025-02-17 03:41:47 |
ジャッジサーバーID (参考情報) | judge3 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | WA * 3 |
other | WA * 53 |
ソースコード
#include <bits/stdc++.h>

#ifndef CLASS_MODINT
#define CLASS_MODINT
#include <cstdint>

/// Integer kept canonically in [0, mod). `inv()` uses Fermat's little theorem,
/// so it is only correct when `mod` is prime.
template <std::uint32_t mod>
class modint {
private:
    std::uint32_t n;

public:
    modint() : n(0) {}
    // Intentionally implicit: lets expressions such as `mint * long long` compile.
    // Negative inputs are mapped into [0, mod).
    modint(std::int64_t n_) : n((n_ >= 0 ? n_ : mod - (-n_) % mod) % mod) {}

    static constexpr std::uint32_t get_mod() { return mod; }
    std::uint32_t get() const { return n; }

    bool operator==(const modint& m) const { return n == m.n; }
    bool operator!=(const modint& m) const { return n != m.n; }

    modint& operator+=(const modint& m) {
        n += m.n;
        n = (n < mod ? n : n - mod);
        return *this;
    }
    modint& operator-=(const modint& m) {
        n += mod - m.n;
        n = (n < mod ? n : n - mod);
        return *this;
    }
    modint& operator*=(const modint& m) {
        n = std::uint64_t(n) * m.n % mod;
        return *this;
    }

    modint operator+(const modint& m) const { return modint(*this) += m; }
    modint operator-(const modint& m) const { return modint(*this) -= m; }
    modint operator*(const modint& m) const { return modint(*this) *= m; }

    /// Multiplicative inverse via a^(mod-2); requires prime `mod` and a nonzero value.
    modint inv() const { return (*this).pow(mod - 2); }

    /// Binary exponentiation: returns (*this)^b.
    modint pow(std::uint64_t b) const {
        modint ans = 1, m = modint(*this);
        while (b) {
            if (b & 1) ans *= m;
            m *= m;
            b >>= 1;
        }
        return ans;
    }
};
#endif  // CLASS_MODINT

using mint = modint<998244353>;

/// Euclidean gcd (recursive).
long long mygcd(long long x, long long y) {
    if (y == 0) {
        return x;
    }
    return mygcd(y, x % y);
}

/// Multiplies two elements a = a[0]*(J - I) + a[1]*I and b (same form), where J is the
/// X-by-X all-ones matrix. Uses (J - I)^2 = (X - 2)(J - I) + (X - 1)I.
/// {0, 1} is the identity element, and the product is commutative.
std::array<mint, 2> merge(long long X, const std::array<mint, 2>& v1,
                          const std::array<mint, 2>& v2) {
    std::array<mint, 2> res = {mint(0), mint(0)};
    res[1] += v1[1] * v2[1];
    res[0] += v1[0] * v2[1];
    res[0] += v1[1] * v2[0];
    res[0] += v1[0] * v2[0] * (X - 2);
    res[1] += v1[0] * v2[0] * (X - 1);
    return res;
}

/// Computes the answer for inputs (N, M, K) modulo 998244353.
/// Assumes N >= 1, M >= 1, K >= 0 (NOTE(review): confirm against the problem's constraints).
/// NOTE(review): the combinatorial derivation of `dp` / `ini` below was not re-verified
/// against the problem statement; only the exponentiation bug at the end was fixed.
mint solve(long long N, long long M, long long K) {
    // seq[i] = gcd(K^i, M), listed until the value stops changing.
    std::vector<long long> seq = {1};
    while (true) {
        const long long cur = seq.back();
        // cur divides M, so gcd(cur*K, M) == cur * gcd(K, M / cur).
        // This avoids overflowing cur*K, and the result is at most M.
        const long long next = cur * mygcd(K, M / cur);
        if (next == cur) {
            break;
        }
        seq.push_back(next);
    }
    if (static_cast<long long>(seq.size()) > N) {
        seq.resize(static_cast<std::size_t>(N));
    }
    const int V = static_cast<int>(seq.size());

    // Sentinel so that d[V - 1] is defined.
    // NOTE(review): unverified that K (rather than another value) is intended here.
    seq.push_back(K);
    std::vector<long long> d(V);
    for (int i = 0; i < V; i++) {
        d[i] = seq[i + 1] / seq[i];
    }

    std::vector<std::vector<mint>> dp(V);
    dp[0] = {mint(1)};
    for (int i = 1; i < V; i++) {
        dp[i].resize(i + 1);
        for (int j = 0; j < i; j++) {
            dp[i][j] += dp[i - 1][j] * seq[i - 1];
            dp[i][j + 1] += dp[i - 1][j];
        }
        // Index 0 is scaled by 1 (no-op); index j >= 1 by d[i-1]^(j-1).
        for (int j = 1; j <= i; j++) {
            dp[i][j] *= mint(d[i - 1]).pow(j - 1);
        }
    }

    // Alternating-sign combination of the last dp row into an algebra element.
    // The i == 0 term contributes only to slot [0]; i >= 1 terms contribute equally to both slots.
    std::array<mint, 2> ini = {mint(0), mint(0)};
    for (int i = 0; i < V; i++) {
        const mint scale = (i >= 1 ? mint(d[V - 1]).pow(i - 1) : mint(1));
        const mint t0 = dp[V - 1][i] * scale;
        const mint t1 = (i >= 1 ? t0 : mint(0));
        if (i % 2 == 0) {
            ini[0] += t0;
            ini[1] += t1;
        } else {
            ini[0] -= t0;
            ini[1] -= t1;
        }
    }

    // Per-step element s*(J - I) + (s - 1)*I, with s = seq[V - 1] and X = M / s.
    const long long X = M / seq[V - 1];
    std::array<mint, 2> g = {mint(seq[V - 1]), mint(seq[V - 1])};
    g[1] -= 1;

    // v = g^(N - V) by binary exponentiation; {0, 1} is the identity of `merge`.
    std::array<mint, 2> v = {mint(0), mint(1)};
    for (long long rem = N - V; rem >= 1; rem /= 2) {
        if (rem % 2 == 1) {
            v = merge(X, g, v);
        }
        g = merge(X, g, g);
    }

    // Fix: combine with the accumulated power `v`. The original used `g`, which by now
    // is the base squared past the needed exponent, and left `v` unused.
    const std::array<mint, 2> ans = merge(X, ini, v);
    return ans[1];
}

int main() {
    long long N, M, K;
    std::cin >> N >> M >> K;
    const mint ans = solve(N, M, K);
    std::cout << ans.get() << '\n';
    return 0;
}