結果

問題 No.2763 Macaron Gift Box
ユーザー PNJPNJ
提出日時 2024-11-21 03:03:42
言語 C++23
(gcc 12.3.0 + boost 1.83.0)
結果
RE  
実行時間 -
コード長 15,206 bytes
コンパイル時間 3,601 ms
コンパイル使用メモリ 266,268 KB
実行使用メモリ 12,792 KB
最終ジャッジ日時 2024-11-21 03:03:52
合計ジャッジ時間 6,344 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
6,816 KB
testcase_01 AC 2 ms
6,816 KB
testcase_02 AC 2 ms
6,816 KB
testcase_03 AC 2 ms
6,816 KB
testcase_04 AC 2 ms
6,820 KB
testcase_05 AC 2 ms
6,820 KB
testcase_06 AC 2 ms
6,816 KB
testcase_07 AC 142 ms
7,828 KB
testcase_08 AC 35 ms
6,816 KB
testcase_09 AC 72 ms
6,816 KB
testcase_10 RE -
testcase_11 RE -
testcase_12 AC 474 ms
12,792 KB
testcase_13 RE -
testcase_14 AC 34 ms
6,820 KB
testcase_15 AC 34 ms
6,816 KB
testcase_16 AC 35 ms
6,820 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

// Shorthand aliases for std::vector nested one to five levels deep.
template <class T> using vc = vector<T>;
template <class T> using vvc = vector<vector<T>>;
template <class T> using vvvc = vector<vector<vector<T>>>;
template <class T> using vvvvc = vector<vector<vector<vector<T>>>>;
template <class T> using vvvvvc = vector<vector<vector<vector<vector<T>>>>>;

// Python-style spelling of `else if`.
#define elif else if

// Counting loops. Every loop variable is declared as `ll` (long long).
//   FOR1(a)        : repeat a times using the hidden counter `_`
//   FOR2(i,n)      : i = 0, 1, ..., n-1
//   FOR3(i,l,r)    : i = l, l+1, ..., r-1
//   FOR4(i,l,r,c)  : i = l, l+c, ... while i < r
// The *_R variants start from the upper end (a-1, n-1 or r-1) and count down.
#define FOR1(a) for (ll _ = 0; _ < ll(a); _++)
#define FOR2(i,n) for (ll i = 0; i < ll(n); i++)
#define FOR3(i,l,r) for (ll i = l; i < ll(r); i++)
#define FOR4(i,l,r,c) for (ll i = l; i < ll(r); i += (c))
#define FOR1_R(a) for (ll _ = ll(a) - 1; _ >= 0; _--)
#define FOR2_R(i,n) for (ll i = (n) - 1; i >= ll(0); i--)
#define FOR3_R(i,l,r) for (ll i = (r) - 1; i >= ll(l); i--)
#define FOR4_R(i,l,r,c) for (ll i = (r) - 1; i >= ll(l); i -= (c))
// overload4 yields its fifth argument, so FOR / FOR_R dispatch to the variant
// that matches the number of arguments supplied (1 to 4).
#define overload4(a, b, c, d, e, ...) e
#define FOR(...) overload4(__VA_ARGS__, FOR4, FOR3, FOR2, FOR1)(__VA_ARGS__)
#define FOR_R(...) overload4(__VA_ARGS__, FOR4_R, FOR3_R, FOR2_R, FOR1_R)(__VA_ARGS__)
// Range-for over container A, binding each element to a forwarding reference.
#define FOR_each(a,A) for (auto &&a: A)
// Visits t = s, then each (t - 1) & s in turn, down to and including t = 0.
#define FOR_subset(t,s) for(ll t = (s); t >= 0; t = (t == 0 ? -1 : (t - 1) & (s)))

// all(x): begin/end iterator pair of x; len(x): size of x converted to ll.
#define all(x) x.begin(), x.end()
#define len(x) ll(x.size())

// Number of set bits in x. The signed overloads reinterpret x as the unsigned
// type of the same width and delegate to the unsigned overloads.
int popcnt(uint32_t x) {
  return __builtin_popcount(x);
}
int popcnt(uint64_t x) {
  return __builtin_popcountll(x);
}
int popcnt(int x) {
  return popcnt(uint32_t(x));
}
int popcnt(long long x) {
  return popcnt(uint64_t(x));
}
// Index of the highest set bit of x, or -1 when x == 0. The signed overloads
// reinterpret x as the unsigned type of the same width.
int topbit(uint32_t x) {
  if (x == 0) return -1;
  return 31 - __builtin_clz(x);
}
int topbit(uint64_t x) {
  if (x == 0) return -1;
  return 63 - __builtin_clzll(x);
}
int topbit(int x) {
  return topbit(uint32_t(x));
}
int topbit(long long x) {
  return topbit(uint64_t(x));
}

// Input
// rd(x) extracts one value from std::cin into x using operator>>.
// The no-argument overload does nothing.
void rd() {}
void rd(char& c){ cin >> c; }
void rd(string& s){ cin >> s; }
void rd(int& x){ cin >> x; }
void rd(uint32_t& x){ cin >> x; }
void rd(long long& x){ cin >> x; }
void rd(uint64_t& x){ cin >> x; }
// Reads v.size() elements in order; the vector must already have its final size.
template<class T>
void rd(vector<T> & v){
  for (auto& x:v) rd(x);
}

// Recursion terminator: nothing left to read.
void read() {}
// Reads the first argument with rd(), then recurses on the remaining ones,
// so the arguments are filled left to right.
template <class H, class... T>
void read(H& h, T&... t) {
  rd(h);
  read(t...);
}

// Declare-and-read macros. Each one expands to TWO statements (a declaration
// followed by a read(...) call), so it cannot serve as the single unbraced
// body of an if/for. The caller supplies the terminating semicolon.

// Declares the listed names as char and reads them in order.
#define CHAR(...) \
  char __VA_ARGS__; \
  read(__VA_ARGS__)

// Declares the listed names as string and reads them in order.
#define STRING(...) \
  string __VA_ARGS__; \
  read(__VA_ARGS__)

// Declares the listed names as int and reads them in order.
#define INT(...) \
  int __VA_ARGS__; \
  read(__VA_ARGS__)

// Declares the listed names as uint32_t and reads them in order.
#define U32(...) \
  uint32_t __VA_ARGS__; \
  read(__VA_ARGS__)

// Declares the listed names as long long and reads them in order.
#define LL(...) \
  long long __VA_ARGS__; \
  read(__VA_ARGS__)

// Declares the listed names as uint64_t and reads them in order.
#define U64(...) \
  uint64_t __VA_ARGS__; \
  read(__VA_ARGS__)

// Declares vector<t> a with n elements and reads all of them.
#define VC(t, a, n) \
  vector<t> a(n); \
  read(a)

// Declares an h-by-w vector<vector<t>> a and reads it row by row.
#define VVC(t, a, h, w) \
  vector<vector<t>> a(h, vector<t>(w)); \
  read(a)

// Output
// wt(x) writes x to std::cout without any trailing separator or newline.
// The no-argument overload writes nothing.
void wt() {}
void wt(const char c){ cout << c; }
// Taken by const reference so that printing does not copy the string.
void wt(const string& s){ cout << s; }
void wt(int x){ cout << x; }
void wt(uint32_t x) { cout << x; }
void wt(long long x){ cout << x; }
void wt(uint64_t x) { cout << x; }
// Writes the elements of v separated by single spaces (nothing for an empty
// vector). Taken by const reference: the by-value signature used to copy the
// whole vector on every call.
template<class T>
void wt(const vector<T>& v){
  bool first = true;
  for (const auto& item : v){
    if (!first) wt(' ');
    first = false;
    wt(item);
  }
}

// With no arguments, print() just ends the line.
void print() {
  wt('\n');
}
// Writes the arguments separated by single spaces and terminates the line.
template <class Head, class... Tail>
void print(Head&& head, Tail&&... tail) {
  wt(head);
  // A separator is only needed when more arguments follow.
  if constexpr (sizeof...(Tail) > 0) {
    wt(' ');
  }
  print(std::forward<Tail>(tail)...);
}

// Integer arithmetic modulo the compile-time constant `mod`.
// The residue is stored in `val`; the constructors and operators below keep
// it reduced to the range [0, mod).
template <int mod>
struct modint {
  static constexpr uint32_t umod = uint32_t(mod);
  // Required so that the sum of two residues still fits in uint32_t.
  static_assert(umod < (uint32_t(1) << 31));
  uint32_t val;

  // Builds a value from an unsigned integer, reducing it modulo `mod`.
  static modint raw(uint32_t v){
    modint out;
    out.val = v % mod;
    return out;
  }

  constexpr modint() : val(0) {}
  // Unsigned sources: a plain remainder is enough.
  constexpr modint(uint32_t x) : val(x % umod) {}
  constexpr modint(uint64_t x) : val(x % umod) {}
  constexpr modint(unsigned __int128 x) : val(x % umod) {}
  // Signed sources: the remainder may be negative, so shift it into [0, mod).
  constexpr modint(int x) : val(0) {
    x %= mod;
    if (x < 0) x += mod;
    val = uint32_t(x);
  }
  constexpr modint(long long x) : val(0) {
    x %= mod;
    if (x < 0) x += mod;
    val = uint32_t(x);
  }
  constexpr modint(__int128 x) : val(0) {
    x %= mod;
    if (x < 0) x += mod;
    val = uint32_t(x);
  }

  // Orders by stored residue (useful for sorting / ordered containers).
  bool operator<(const modint &other) const { return val < other.val; }

  modint &operator+=(const modint &rhs) {
    val += rhs.val;
    if (val >= umod) val -= umod;
    return *this;
  }
  modint &operator-=(const modint &rhs) {
    // Add the additive complement instead of subtracting, to stay unsigned.
    val += umod - rhs.val;
    if (val >= umod) val -= umod;
    return *this;
  }
  modint &operator*=(const modint &rhs) {
    const uint64_t wide = uint64_t(val) * rhs.val;
    val = wide % umod;
    return *this;
  }
  modint &operator/=(const modint &rhs) {
    return *this *= rhs.inverse();
  }

  modint operator-() const {
    if (val == 0) return modint::raw(uint32_t(0));
    return modint::raw(umod - val);
  }
  modint operator+(const modint &rhs) const {
    modint out(*this);
    out += rhs;
    return out;
  }
  modint operator-(const modint &rhs) const {
    modint out(*this);
    out -= rhs;
    return out;
  }
  modint operator*(const modint &rhs) const {
    modint out(*this);
    out *= rhs;
    return out;
  }
  modint operator/(const modint &rhs) const {
    modint out(*this);
    out /= rhs;
    return out;
  }
  bool operator==(const modint &rhs) const { return val == rhs.val; }
  bool operator!=(const modint &rhs) const { return !(val == rhs.val); }

  // Multiplicative inverse via the extended Euclidean algorithm.
  // No check is made that the value is invertible; for val == 0 the loop
  // yields 0.
  modint inverse() const {
    int r0 = val, r1 = mod;  // remainder sequence
    int s0 = 1, s1 = 0;      // matching coefficients of the original value
    while (r1 != 0) {
      const int quot = r0 / r1;
      r0 -= quot * r1;
      std::swap(r0, r1);
      s0 -= quot * s1;
      std::swap(s0, s1);
    }
    return modint(s0);
  }

  // Raises the value to the non-negative power n by binary exponentiation.
  modint pow(long long n) const {
    assert(n >= 0);
    modint acc(1), base(val);
    for (; n > 0; n >>= 1) {
      if (n & 1) acc *= base;
      base *= base;
    }
    return acc;
  }

  static constexpr int get_mod() { return mod; }

  // Returns the pair read by ntt(): `.first` is used there as `rank2` and
  // `.second` as `root[rank2]`. Yields {-1,-1} for a modulus not listed here.
  static constexpr std::pair<int,int> ntt_info() {
    switch (mod) {
      case 120586241:  return {20,74066978};
      case 167772161:  return {25,17};
      case 469762049:  return {26,30};
      case 754974721:  return {24,362};
      case 880803841:  return {23,211};
      case 924844033:  return {21,44009197};
      case 943718401:  return {22,663003469};
      case 998244353:  return {23,31};
      case 1045430273: return {20,363};
      case 1051721729: return {20,330};
      case 1053818881: return {20,2789};
      default:         return {-1,-1};
    }
  }
};

// Reads an unsigned 32-bit integer from std::cin and stores it, reduced
// modulo `mod`, into x.
template <int mod>
void rd(modint<mod>& x){
  uint32_t raw_value;
  cin >> raw_value;
  x = modint<mod>(raw_value);
}

// Writes the stored residue of x as an unsigned integer.
template <int mod>
void wt(modint<mod> x){
  const uint32_t residue = x.val;
  wt(residue);
}

// Modular integer type used by the non-template helpers and by main();
// 998244353 is one of the moduli listed in modint::ntt_info().
using mint = modint<998244353>;

// Returns n! as a mint. Results are memoized in a function-local table that
// grows on demand and persists across calls.
mint fact(int n) {
  static vector<mint> table = {1, 1};
  static int filled = 1;
  for (; filled <= n; ) {
    ++filled;
    // At this point table holds entries 0 .. filled-1, so back() is (filled-1)!.
    table.push_back(table.back() * filled);
  }
  return table[n];
}

// Returns 1 / n! as a mint. Results are memoized in a function-local table
// that grows on demand; each new entry divides the previous one by its index.
mint fact_inv(int n) {
  static vector<mint> table = {1, 1};
  static int filled = 1;
  for (; filled <= n; ) {
    ++filled;
    // At this point table holds entries 0 .. filled-1, so back() is 1/(filled-1)!.
    table.push_back(table.back() / filled);
  }
  return table[n];
}

// Binomial coefficient C(n, r) as a mint; 0 when n < r or either argument
// is negative.
mint binom(int n, int r){
  const bool out_of_range = (n < r) || (n < 0) || (r < 0);
  if (out_of_range) return 0;
  const mint denominator_inv = fact_inv(n - r) * fact_inv(r);
  return fact(n) * denominator_inv;
}

// In-place number-theoretic transform of `a`.
//   inverse == false : forward transform.
//   inverse == true  : every element is first scaled by 1/n, then the inverse
//                      butterflies are applied.
// The level count is h = topbit(a.size()); the butterflies index a[] as if
// a.size() == 1 << h. NOTE(review): presumably callers always pass a
// power-of-two length (convolute, fps_inv and fps_exp in this file size their
// buffers that way) - confirm before calling with other sizes.
// `mint` must provide ntt_info(), get_mod(), a public `val` and the usual
// arithmetic operators (see modint).
template <class mint>
void ntt(vector<mint> &a, bool inverse) {
  const int rank2 = mint::ntt_info().first;
  const int mod = mint::get_mod();
  // Twiddle tables, filled once per template instantiation (guarded by `prepared`).
  static array<mint, 30> root, rate2, rate3, iroot, irate2, irate3;

  static bool prepared = 0;
  if (!prepared){
    prepared = 1;

    // root[rank2] comes from ntt_info(); each lower entry is the square of the
    // one above it. iroot[] holds the corresponding inverses.
    root[rank2] = mint::ntt_info().second;
    iroot[rank2] = mint(1) / root[rank2];
    for (int i = rank2 - 1; i >= 0; i--){
      root[i] = root[i + 1] * root[i + 1];
      iroot[i] = iroot[i + 1] * iroot[i + 1];
    }

    // rate2 / irate2: multipliers used to advance the twiddle between
    // consecutive blocks of the radix-2 step.
    mint prod = 1, iprod = 1;
    for (int i = 0; i < rank2; i++){
      rate2[i] = root[i + 2] * prod;
      irate2[i] = iroot[i + 2] * iprod;
      prod *= iroot[i + 2];
      iprod *= root[i + 2];
    }

    // rate3 / irate3: the same role for the radix-4 step.
    prod = 1, iprod = 1;
    for (int i = 0; i < rank2 - 1; i++){
      rate3[i] = root[i + 3] * prod;
      irate3[i] = iroot[i + 3] * iprod;
      prod *= iroot[i + 3];
      iprod *= root[i + 3];
    }
  }

  int n = a.size();
  int h = topbit(n);
  if (!inverse) {
    // `le` counts the levels already processed; two levels are consumed per
    // pass (radix-4) unless exactly one level remains (radix-2).
    int le = 0;
    while (le < h){
      if (h - le == 1){
        int p = 1 << (h - le - 1);
        mint rot = 1;
        for (int s = 0; s < (1 << le); s++){
          int offset = s << (h - le);
          for (int i = 0; i < p; i++){
            auto l = a[i + offset];
            auto r = a[i + offset + p] * rot;
            a[i + offset] = l + r;
            a[i + offset + p] = l - r;
          }
          // ~s & -~s isolates the lowest zero bit of s; its index selects the rate.
          rot *= rate2[topbit(~s & -~s)];
        }
        le++;
      }
      else{
        int p = 1 << (h - le - 2);
        mint rot = 1, imag = root[2];
        for (int s = 0; s < (1 << le); s++){
          mint rot2 = rot * rot;
          mint rot3 = rot2 * rot;
          int offset = s << (h - le);
          for (int i = 0; i < p; i++){
            // Work on raw 64-bit values; multiples of mod^2 are added before
            // each subtraction so intermediate results stay non-negative.
            // The assignments back into a[] reduce modulo mod.
            uint64_t mod2 = uint64_t(mod) * mod;
            uint64_t a0 = a[i + offset].val;
            uint64_t a1 = uint64_t(a[i + offset + p].val) * rot.val;
            uint64_t a2 = uint64_t(a[i + offset + p * 2].val) * rot2.val;
            uint64_t a3 = uint64_t(a[i + offset + p * 3].val) * rot3.val;
            uint64_t a1na3imag = (a1 + mod2 - a3) % mod * imag.val;
            a[i + offset] = a0 + a2 + a1 + a3;
            a[i + offset + p] = a0 + a2 + (2 * mod2 - (a1 + a3));
            a[i + offset + p * 2] = a0 + mod2 - a2 + a1na3imag;
            a[i + offset + p * 3] = a0 + mod2 - a2 + (mod2 - a1na3imag);
          }
          rot = rot * rate3[topbit(~s & -~s)];
        }
        le = le + 2;
      }
    }
  }
  else{
    // Apply the 1/n scaling up front, then run the levels from h down to 0.
    mint coef = mint(1) / mint(n);
    for (int i = 0; i < n; i++){
      a[i] *= coef;
    }
    int le = h;
    while (le){
      if (le == 1){
        int p = 1 << (h - le);
        mint irot = 1;
        for (int s = 0; s < (1 << (le - 1)); s++){
          int offset = s << (h - le + 1);
          for (int i = 0; i < p; i++){
            uint64_t l = a[i + offset].val;
            uint64_t r = a[i + offset + p].val;
            a[i + offset] = l + r;
            // mod is added so that l - r cannot go negative.
            a[i + offset + p] = (mod + l - r) * irot.val;
          }
          irot *= irate2[topbit(~s & -~s)];
          }
        le--;
      }
      else{
        int p = 1 << (h - le);
        mint irot = 1, iimag = iroot[2];
        for (int s = 0; s < (1 << (le - 2)); s++){
          mint irot2 = irot * irot;
          mint irot3 = irot2 * irot;
          int offset = s << (h - le + 2);
          for (int i = 0; i < p; i++){
            uint64_t a0 = a[i + offset].val;
            uint64_t a1 = a[i + offset + p].val;
            uint64_t a2 = a[i + offset + p * 2].val;
            uint64_t a3 = a[i + offset + p * 3].val;
            uint64_t a2na3iimag = (mod + a2 - a3) * iimag.val % mod;
            a[i + offset] = a0 + a1 + a2 + a3;
            a[i + offset + p] = (a0 + mod - a1 + a2na3iimag) * irot.val;
            a[i + offset + p * 2] = (a0 + a1 + 2 * mod - a2 - a3) * irot2.val;
            a[i + offset + p * 3] = (a0 + 2 * mod - a1 - a2na3iimag) * irot3.val;
          }
          irot *= irate3[topbit(~s & -~s)];
        }
        le = le - 2;
      }
    }
  }
}

// Schoolbook O(|a| * |b|) product of two coefficient sequences.
// Returns a vector of length |a| + |b| - 1 whose k-th entry is the sum of
// a[i] * b[j] over all i + j == k. If either input is empty the product is
// empty (the unguarded size formula would underflow in that case).
// The element type only needs default construction, operator+ and operator*.
template <class mint>
vector<mint> convolute_naive(const vector<mint>& a, const vector<mint>& b){
  if (a.empty() || b.empty()) return {};
  vector<mint> res(a.size() + b.size() - 1);
  for (size_t i = 0; i < a.size(); i++){
    for (size_t j = 0; j < b.size(); j++){
      res[i + j] = res[i + j] + a[i] * b[j];
    }
  }
  return res;
}

// Product of two coefficient sequences. Falls back to the schoolbook routine
// when the shorter input has at most 60 entries; otherwise pads both inputs to
// a common power-of-two length, multiplies pointwise in the transform domain
// and transforms back.
template <class mint>
vector<mint> convolute(vector<mint> a, vector<mint> b){
  const int n = a.size();
  const int m = b.size();
  if (n <= 60 || m <= 60) return convolute_naive(a,b);

  const int out_len = n + m - 1;
  int le = 1;
  while (le < out_len) le <<= 1;

  a.resize(le);
  b.resize(le);
  ntt(a, 0);
  ntt(b, 0);
  for (int idx = 0; idx < le; idx++) a[idx] *= b[idx];
  ntt(a, 1);

  a.resize(out_len);
  return a;
}

// Returns the first `deg` coefficients of a series g, built by doubling the
// known prefix length each round; by its name and use this is the
// multiplicative inverse of f as a power series.
// Requires f[0] != 0 (asserted). deg == -1 means "use f.size()".
template <class mint>
vector<mint> fps_inv(vector<mint> f, int deg = -1){
  assert (f[0] != 0);
  if (deg == -1) deg = int(f.size());
  vector<mint> res(deg);
  res[0] = f[0].inverse();
  for (int d = 1; d < deg; d *= 2){
    const int width = d * 2;

    // a: the first `width` coefficients of f (zero-padded);
    // b: the d coefficients of the answer found so far (zero-padded).
    const int copy_len = min(int(f.size()), width);
    vector<mint> a(f.begin(), f.begin() + copy_len);
    a.resize(width);
    vector<mint> b(res.begin(), res.begin() + d);
    b.resize(width);

    ntt(a, 0);
    ntt(b, 0);
    for (int i = 0; i < width; i++) a[i] *= b[i];
    ntt(a, 1);

    // Drop the low half and multiply by b once more.
    for (int i = 0; i < d; i++) a[i] = 0;
    ntt(a, 0);
    for (int i = 0; i < width; i++) a[i] *= b[i];
    ntt(a, 1);

    // The next block of coefficients is the negated upper half
    // (unary minus already maps 0 to 0).
    const int upper = min(width, deg);
    for (int j = d; j < upper; j++) res[j] = -a[j];
  }
  return res;
}

// Formal derivative: entry i-1 of the result is f[i] * i.
// For an input with at most one coefficient the result is the single value 0.
template <class mint>
vector<mint> fps_diff(vector<mint> f){
  const int n = int(f.size());
  if (n <= 1) return {mint(0)};
  vector<mint> res(n - 1);
  for (int i = 1; i < n; i++){
    res[i - 1] = f[i] * i;
  }
  return res;
}

// Formal integral with zero constant term: the result has one more entry
// than f, with result[0] == 0 and result[i + 1] == f[i] / (i + 1).
template <class mint>
vector<mint> fps_integrate(vector<mint> f){
  vector<mint> res(f.size() + 1);
  int k = 0;
  for (const mint& coef : f){
    ++k;
    res[k] = coef / k;
  }
  return res;
}

template <class mint>
vector<mint> fps_log(vector<mint> f, int deg = -1){
  assert (f[0] == 1);
  if (deg == -1) deg = int(f.size());
  vc<mint> res = convolute(fps_diff(f), fps_inv(f, deg));
  res = fps_integrate(res);
  res.resize(deg);
  return res;
}

// Returns `deg` coefficients of a series built from f by a doubling scheme;
// by its name and its use in fps_pow this is presumably exp(f) truncated to
// `deg` terms - the derivation itself is not spelled out in this file.
// Requires f[0] == 0 (asserted; f must therefore be non-empty).
// deg == -1 means "use f.size()".
template <class mint>
vector<mint> fps_exp(vector<mint> f, int deg = -1){
  assert (f[0] == 0);
  if (deg == -1) deg = int(f.size());
  // res: answer prefix known so far (length m at the top of each round).
  vector<mint> res = {1, 0};
  if (f.size() > 1) res[1] = f[1];
  // g: auxiliary series extended by m/2 entries per round;
  // q: transform of g zero-padded to twice its length; p: q from the previous round.
  vector<mint> g, p, q;
  g = {1}, q = {1, 1};
  int m = 2;
  while (m < deg){
    // y = transform of res zero-padded to length 2m.
    vector<mint> y = res;
    y.resize(int(res.size()) + m);
    ntt(y, 0);
    p = q;
    // Only the first p.size() entries of y are used here.
    vector<mint> z(int(p.size()));
    for (int i = 0; i < int(p.size()); i++) z[i] = y[i] * p[i];
    ntt(z, 1);
    for (int i = 0; i < m / 2; i++) z[i] = 0;
    ntt(z, 0);
    for (int i = 0; i < int(p.size()); i++) z[i] *= -p[i];
    ntt(z, 1);
    // Extend g with the upper half of z, then refresh its padded transform q.
    for (int i = m / 2; i < int(z.size()); i++) g.push_back(z[i]);
    q = g;
    q.resize(int(q.size()) + m);
    ntt(q, 0);
    // x starts as the derivative of the first m coefficients of f
    // (padded back to length m with a trailing zero).
    vector<mint> x(m);
    for (int i = 0; i < min(int(f.size()), m); i++) x[i] = f[i];
    x = fps_diff(x);
    x.push_back(0);
    ntt(x, 0);
    for (int i = 0; i < int(x.size()); i++) x[i] *= y[i];
    ntt(x, 1);
    // Subtract the derivative of res.
    for (int i = 1; i < int(res.size()); i++) x[i - 1] -= res[i] * i;
    // Grow x to length 2m and move its first m-1 entries into the upper half.
    for (int i = 0; i < m; i++) x.push_back(0);
    for (int i = 0; i < m - 1; i++){
      swap(x[m + i], x[i]);
      x[i] = 0;
    }
    ntt(x, 0);
    for (int i = 0; i < int(q.size()); i++) x[i] *= q[i];
    ntt(x, 1);
    // Drop the last entry so that integrating restores length 2m.
    x.pop_back();
    x = fps_integrate(x);
    for (int i = 0; i < m; i++) x[i] = 0;
    for (int i = m; i < min(int(f.size()), m * 2); i++) x[i] += f[i];
    ntt(x, 0);
    for (int i = 0; i < int(y.size()); i++) x[i] *= y[i];
    ntt(x, 1);
    // The upper half of x supplies the next m coefficients of the answer.
    for (int i = m; i < int(x.size()); i++) res.push_back(x[i]);
    m *= 2;
  }
  res.resize(deg);
  return res;
}

// Returns the first `deg` coefficients of f raised to the k-th power.
//   f   : coefficient sequence (resized to `deg` internally)
//   k   : exponent, must be non-negative (asserted)
//   deg : number of coefficients to return; -1 means "use f.size()"
// Method: strip the lowest non-zero term a * x^p, normalise the remainder to
// constant term 1, compute exp(k * log(.)), then scale by a^k and shift the
// result by p * k positions.
template <class mint>
vector<mint> fps_pow(vector<mint> f, int k, int deg = -1){
  assert(k >= 0);
  if (deg == -1) deg = f.size();
  // Nothing to return; also keeps the k == 0 branch from writing res[0]
  // into an empty vector.
  if (deg == 0) return {};
  if (k == 0){
    vector<mint> res(deg);
    res[0] = 1;
    return res;
  }
  f.resize(deg);

  // p = index of the lowest non-zero coefficient (deg if f is identically 0).
  int p = 0;
  while (p < deg && f[p].val == 0) p++;

  // The lowest term of the result sits at index p * k. Evaluate the product
  // in 64 bits: with a large k it does not fit in int.
  if ((long long)p * k >= deg) return vector<mint>(deg);
  const int shift = p * k;  // known to be < deg here, so it fits in int

  mint a = f[p];
  const mint a_inv = a.inverse();  // hoisted: one inversion instead of deg - p
  vector<mint> g(deg - p);
  for (int i = 0; i < deg - p; i++) g[i] = f[i + p] * a_inv;
  g = fps_log(g);
  for (auto& coef : g) coef *= k;
  g = fps_exp(g);

  a = a.pow(k);
  vector<mint> res(deg);
  for (int i = 0; i + shift < deg; i++) res[i + shift] = g[i] * a;
  return res;
}

int main(){
  INT(N, K);
  vc<mint> L(N + 1);
  FOR(i, 1, N + 1){
    FOR(j, 1, N + 1){
      if (i * j > N) break;
      L[i * j] += mint(j).inverse();
    }
    int k = (K + 1) * i;
    FOR(j, 1, N + 1){
      if (k * j > N) break;
      L[k * j] -= mint(j).inverse();
    }
  }
  vc<mint> f = fps_exp(L);
  vc<mint> g;
  FOR(i, 1, f.size()) g.push_back(f[i]);
  print(g);
}
0