結果

問題 No.2959 Dolls' Tea Party
ユーザー vwxyzvwxyz
提出日時 2024-11-08 23:53:55
言語 C++23 (gcc13) (gcc 13.2.0 + boost 1.83.0)
結果
TLE  
実行時間 -
コード長 13,890 bytes
コンパイル時間 5,099 ms
コンパイル使用メモリ 308,284 KB
実行使用メモリ 20,096 KB
最終ジャッジ日時 2024-11-08 23:54:11
合計ジャッジ時間 15,141 ms
ジャッジサーバーID (参考情報): judge5 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 | 結果 | 実行時間 | 実行使用メモリ
testcase_00 | AC | 2 ms | 20,096 KB
testcase_01 | AC | 2 ms | 5,248 KB
testcase_02 | AC | 2 ms | 5,248 KB
testcase_03 | AC | 2 ms | 5,248 KB
testcase_04 | AC | 2 ms | 5,248 KB
testcase_05 | AC | 2 ms | 5,248 KB
testcase_06 | AC | 1,618 ms | 8,064 KB
testcase_07 | AC | 1,718 ms | 8,384 KB
testcase_08 | AC | 1,684 ms | 8,108 KB
testcase_09 | TLE | - | -
testcase_10 | -- | - | -
testcase_11 | -- | - | -
testcase_12 | -- | - | -
testcase_13 | -- | - | -
testcase_14 | -- | - | -
testcase_15 | -- | - | -
testcase_16 | -- | - | -
testcase_17 | -- | - | -
testcase_18 | -- | - | -
testcase_19 | -- | - | -
testcase_20 | -- | - | -
testcase_21 | -- | - | -
testcase_22 | -- | - | -
testcase_23 | -- | - | -
testcase_24 | -- | - | -
testcase_25 | -- | - | -
testcase_26 | -- | - | -
testcase_27 | -- | - | -
testcase_28 | -- | - | -
testcase_29 | -- | - | -
testcase_30 | -- | - | -
testcase_31 | -- | - | -
testcase_32 | -- | - | -
testcase_33 | -- | - | -
testcase_34 | -- | - | -
testcase_35 | -- | - | -
testcase_36 | -- | - | -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#include <vector>

using namespace std;
// ---- Common type aliases, constants and competitive-programming shorthands ----
using uint = unsigned int;
using ll = long long;
using ull = unsigned long long;
// Same value as the modulus of Mint (ModInt<998244353>) defined below.
const int MOD = 998244353;
template<class T> using V = vector<T>;
template<class T> using VV = V<V<T>>;
// TEN(n) == 10^n as a long long; exceeds the long long range for n > 18.
constexpr ll TEN(int n) { return (n == 0) ? 1 : 10 * TEN(n-1); }
// Loop shorthands: FOR iterates i over [a, b), rep over [0, N), rep1 over [1, N].
#define FOR(i, a, b) for(int i=(int)(a);i<(int)(b);i++)
#define rep(i,N) for(int i=0;i<(int)(N);i++)
#define rep1(i,N) for(int i=1;i<=(int)(N);i++)
#define fs first
#define sc second
#define eb emplace_back
// Note: pb expands to emplace_back (through eb), not push_back.
#define pb eb
#define all(x) x.begin(),x.end()
// chmin / chmax: overwrite t with u when u is smaller / larger.
template<class T, class U> void chmin(T& t, const U& u) { if (t > u) t = u; }
template<class T, class U> void chmax(T& t, const U& u) { if (t < u) t = u; }
// show(x): when LOCAL is defined, prints "<line> : <expr> = <value>" to stderr;
// otherwise it expands to `true` and prints nothing.
#ifdef LOCAL
#define show(x) cerr << __LINE__ << " : " << #x << " = " << (x) << endl
#else
#define show(x) true
#endif
// Debug printer for pairs: writes "P(first, second)".
template <class T, class U>
ostream& operator<<(ostream& os, const pair<T, U>& p) {
  os << "P(" << p.first << ", " << p.second << ")";
  return os;
}
// Debug printer for vectors: writes "[e0, e1, ..., ]" (every element,
// including the last, is followed by ", ").
template <class T> ostream& operator<<(ostream& os, const V<T>& v) {
  os << "[";
  for (size_t idx = 0; idx < v.size(); ++idx) {
    os << v[idx] << ", ";
  }
  os << "]";
  return os;
}
// cin.tie(nullptr);
// ios::sync_with_stdio(false);
// cout << fixed << setprecision(20);

// Integer arithmetic modulo the compile-time modulus MD.
// The stored value v is always normalized into [0, MD). Sums of two residues
// are computed in unsigned int, so MD must satisfy 2 * MD < 2^32.
template <unsigned int MD> struct ModInt {
  using M = ModInt;
  // Generator used by the NTT; defined per modulus by explicit specialization.
  const static M G;
  unsigned int v;
  // Residue of any (possibly negative) integer.
  ModInt(long long _v = 0) { set_v(_v % MD + MD); }
  // Stores a value known to lie in [0, 2 * MD), reducing it once.
  M& set_v(unsigned int _v) {
    v = (_v < MD) ? _v : _v - MD;
    return *this;
  }
  explicit operator bool() const { return v != 0; }
  M operator-() const { return M() - *this; }
  M operator+(const M& r) const { return M().set_v(v + r.v); }
  M operator-(const M& r) const { return M().set_v(v + MD - r.v); }
  M operator*(const M& r) const { return M().set_v(static_cast<unsigned long long>(v) * r.v % MD); }
  M operator/(const M& r) const { return *this * r.inv(); }
  M& operator+=(const M& r) { return *this = *this + r; }
  M& operator-=(const M& r) { return *this = *this - r; }
  M& operator*=(const M& r) { return *this = *this * r; }
  M& operator/=(const M& r) { return *this = *this / r; }
  bool operator==(const M& r) const { return v == r.v; }
  // Returns this^n by binary exponentiation.
  // A negative n is interpreted as (this^-1)^|n|: the exponent is reduced
  // modulo MD - 1 (Fermat's little theorem, so MD is assumed prime). Without
  // this, `n >>= 1` on a negative n never reaches 0 and the loop never ends.
  // For a zero base the result is 0 for every n != 0, matching inv() of 0.
  M pow(long long n) const {
    if (n < 0) {
      // Map n into (0, MD - 1] without negating it (which could overflow).
      const long long period = static_cast<long long>(MD) - 1;
      n %= period;
      n += period;
    }
    M x = *this, r = 1;
    while (n) {
      if (n & 1) r *= x;
      x *= x;
      n >>= 1;
    }
    return r;
  }
  // Multiplicative inverse via Fermat (requires MD prime; inv() of 0 is 0).
  M inv() const { return pow(MD - 2); }
  friend std::ostream& operator<<(std::ostream& os, const M& r) { return os << r.v; }
};

// NTT generator for the modulus 998244353.
template<>
const ModInt<998244353> ModInt<998244353>::G = 3;  // an appropriate primitive root is specified here

// Field type used throughout the rest of the file.
using Mint = ModInt<998244353>;



// Returns base^exp mod `mod`, always in [0, mod), for exp >= 0 (exp <= 0
// yields 1 % mod). base may be negative or >= mod; it is normalized first so
// intermediate products stay non-negative. Requires mod >= 1 and
// (mod - 1)^2 to fit in long long.
long long modpow(long long base, long long exp, long long mod) {
  base %= mod;
  if (base < 0) base += mod;
  long long res = 1 % mod;
  while (exp > 0) {
    if (exp % 2 == 1) res = res * base % mod;
    base = base * base % mod;
    exp /= 2;
  }
  return res;
}
// Returns some x in [0, p) with x^2 ≡ N (mod p), or -1 if none exists.
// p must be prime. N may be any integer; it is reduced into [0, p) first
// (the C++ % operator keeps the sign, so a negative N used to break the
// Euler-criterion comparisons below).
long long Tonelli_Shanks(long long N, long long p) {
  N %= p;
  if (N < 0) N += p;
  if (p == 2) {
    return N;
  }
  // Euler's criterion: N^((p-1)/2) == -1 iff N is a quadratic non-residue.
  if (modpow(N, p >> 1, p) == p - 1) {
    return -1;
  }
  if (p % 4 == 3) {
    return modpow(N, (p + 1) / 4, p);
  }
  // Find a quadratic non-residue z.
  long long z = 1;
  while (z < p && modpow(z, p >> 1, p) != p - 1) {
    ++z;
  }
  // Write p - 1 = q * 2^c with q odd.
  long long q = p - 1;
  int c = 0;
  while (q % 2 == 0) {
    q /= 2;
    ++c;
  }
  // Invariant: r^2 ≡ N * s; drive s to 1 by clearing the high 2-power bits.
  long long s = modpow(N, q, p);
  long long r = modpow(N, (q + 1) / 2, p);
  for (int i = c - 2; i >= 0; --i) {
    if (modpow(s, 1LL << i, p) == p - 1) {
      s = s * modpow(z, p >> (1 + i), p) % p;
      r = r * modpow(z, p >> (2 + i), p) % p;
    }
  }
  return r;
}


// Number-theoretic transform of `a` over Mint, result left in `a`.
// type == false uses the roots ep[i]; type == true uses their inverses iep[i].
// The inverse direction does NOT divide by n; the caller (e.g. multiply)
// must scale by 1/n. a.size() must be a power of two (asserted).
void nft(bool type, V<Mint>& a) {
  int n = int(a.size()), s = 0;
  while ((1 << s) < n) s++;
  assert(1 << s == n);
  // ep[k] = G^((MD-1) / 2^k): a 2^k-th root of unity (primitive if G is a
  // primitive root). Mint(-1).v is MD - 1. Cached in function-local statics
  // and extended lazily to the largest size requested so far.
  static V<Mint> ep, iep;
  while (int(ep.size()) <= s) {
    ep.push_back(Mint::G.pow(Mint(-1).v / (1 << ep.size())));
    iep.push_back(ep.back().inv());
  }
  // Scratch buffer: every stage reads from `a`, writes into `b`, then swaps.
  V<Mint> b(n);
  for (int i = 1; i <= s; i++) {
    // Butterfly span for this stage; the twiddle `now` advances by `base`
    // once per group of w butterflies.
    int w = 1 << (s - i);
    Mint base = type ? iep[i] : ep[i], now = 1;
    for (int y = 0; y < n / 2; y += w) {
      for (int x = 0; x < w; x++) {
        // y is a multiple of w and x < w, so `|` acts as `+` here:
        // inputs 2y + x and 2y + x + w, outputs y + x and y + x + n/2.
        auto l = a[y << 1 | x];
        auto r = now * a[y << 1 | x | w];
        b[y | x] = l + r;
        b[y | x | n >> 1] = l - r;
      }
      now *= base;
    }
    swap(a, b);
  }
}
// Returns the coefficient-wise product (convolution) of a and b, of length
// a.size() + b.size() - 1, or an empty vector if either input is empty.
// Uses schoolbook multiplication when one side has at most 8 terms,
// otherwise three NTTs of the next power-of-two length.
template <class Mint>
V<Mint> multiply(const V<Mint>& a, const V<Mint>& b) {
  const int la = int(a.size());
  const int lb = int(b.size());
  if (la == 0 || lb == 0) return {};
  const int out_len = la + lb - 1;
  if (min(la, lb) <= 8) {
    V<Mint> out(out_len);
    for (int j = 0; j < lb; j++) {
      for (int i = 0; i < la; i++) out[i + j] += a[i] * b[j];
    }
    return out;
  }
  // Smallest power of two that holds the full (non-cyclic) product.
  int len = 1;
  while (len < out_len) len <<= 1;
  V<Mint> fa(a);
  V<Mint> fb(b);
  fa.resize(len);
  fb.resize(len);
  nft(false, fa);
  nft(false, fb);
  for (int k = 0; k < len; k++) fa[k] *= fb[k];
  nft(true, fa);
  fa.resize(out_len);
  // nft's inverse direction leaves a factor of len to divide out.
  const Mint scale = Mint(len).inv();
  for (auto& coef : fa) coef *= scale;
  return fa;
}

// Dense polynomial / truncated formal power series with coefficients in D.
// v[i] is the coefficient of x^i. The constructor strips trailing zeros, so the
// zero polynomial is stored as an empty vector (size() == 0).
// inv/exp/log/power(…, n) return at most n coefficients; sqrt(n) returns the
// coefficients of x^0..x^n.
template <class D> struct Poly {
  vector<D> v;
  Poly(const vector<D>& _v = {}) : v(_v) { shrink(); }
  // Drops trailing zero coefficients.
  void shrink() {
    while (v.size() && !v.back()) v.pop_back();
  }
  int size() const { return int(v.size()); }
  // Coefficient of x^p; 0 outside the stored range (including negative p).
  D freq(int p) const { return (0 <= p && p < size()) ? v[p] : D(0); }

  Poly operator+(const Poly& r) const {
    int n = max(size(), r.size());
    vector<D> res(n);
    for (int i = 0; i < n; i++) res[i] = freq(i) + r.freq(i);
    return res;
  }
  Poly operator-(const Poly& r) const {
    int n = max(size(), r.size());
    vector<D> res(n);
    for (int i = 0; i < n; i++) res[i] = freq(i) - r.freq(i);
    return res;
  }
  Poly operator*(const Poly& r) const { return {multiply(v, r.v)}; }
  Poly operator*(const D& r) const {
    int n = size();
    vector<D> res(n);
    for (int i = 0; i < n; i++) res[i] = v[i] * r;
    return res;
  }
  Poly operator/(const D& r) const {
    return *this * r.inv();
  }
  // Polynomial quotient (floor division) via the inverse of the reversed divisor.
  Poly operator/(const Poly& r) const {
    if (size() < r.size()) return {{}};
    int n = size() - r.size() + 1;
    return (rev().pre(n) * r.rev().inv(n)).rev(n);
  }
  Poly operator%(const Poly& r) const { return *this - *this / r * r; }
  // Multiplication by x^s.
  Poly operator<<(int s) const {
    vector<D> res(size() + s);
    for (int i = 0; i < size(); i++) res[i + s] = v[i];
    return res;
  }
  // Division by x^s, discarding the s lowest coefficients.
  Poly operator>>(int s) const {
    if (size() <= s) return Poly();
    vector<D> res(size() - s);
    for (int i = 0; i < size() - s; i++) res[i] = v[i + s];
    return res;
  }
  Poly& operator+=(const Poly& r) { return *this = *this + r; }
  Poly& operator-=(const Poly& r) { return *this = *this - r; }
  Poly& operator*=(const Poly& r) { return *this = *this * r; }
  Poly& operator*=(const D& r) { return *this = *this * r; }
  Poly& operator/=(const Poly& r) { return *this = *this / r; }
  Poly& operator/=(const D& r) { return *this = *this / r; }
  Poly& operator%=(const Poly& r) { return *this = *this % r; }
  Poly& operator<<=(const size_t& n) { return *this = *this << n; }
  Poly& operator>>=(const size_t& n) { return *this = *this >> n; }

  // First `le` coefficients. `le` is clamped to [0, size()], so a negative
  // length (reachable e.g. through inv(-1) from log(0)) no longer forms an
  // out-of-range iterator.
  Poly pre(int le) const {
    const int keep = max(0, min(size(), le));
    return {{v.begin(), v.begin() + keep}};
  }
  // Reversed coefficients; with n != -1 the vector is first resized to n.
  Poly rev(int n = -1) const {
    vector<D> res = v;
    if (n != -1) res.resize(n);
    reverse(res.begin(), res.end());
    return res;
  }
  // Formal derivative.
  Poly diff() const {
    vector<D> res(max(0, size() - 1));
    for (int i = 1; i < size(); i++) res[i - 1] = freq(i) * i;
    return res;
  }
  // Formal integral with zero constant term.
  Poly inte() const {
    vector<D> res(size() + 1);
    for (int i = 0; i < size(); i++) res[i + 1] = freq(i) / (i + 1);
    return res;
  }
  // Newton iteration for 1/f: f * f.inv(m) = 1 + g(x) x^m.
  Poly inv(int m) const {
    Poly res = Poly({D(1) / freq(0)});
    for (int i = 1; i < m; i *= 2) {
      res = (res * D(2) - res * res * pre(2 * i)).pre(2 * i);
    }
    return res.pre(m);
  }
  // exp of the series truncated to n terms; requires a zero constant term.
  Poly exp(int n) const {
    assert(freq(0) == 0);
    Poly f({1}), g({1});
    for (int i = 1; i < n; i *= 2) {
      g = (g * 2 - f * g * g).pre(i);
      Poly q = diff().pre(i - 1);
      Poly w = (q + g * (f.diff() - f * q)).pre(2 * i - 1);
      f = (f + f * (*this - w.inte()).pre(2 * i)).pre(2 * i);
    }
    return f.pre(n);
  }
  // log of the series truncated to n terms; requires constant term 1.
  Poly log(int n) const {
    assert(freq(0) == 1);
    auto f = pre(n);
    return (f.diff() * f.inv(n - 1)).pre(n - 1).inte();
  }
  // Square root with constant term 1, returning the coefficients of x^0..x^n.
  // After the step with index i, g is correct modulo x^(2i); the loop runs
  // until 2i >= n + 1 so that every returned coefficient is exact (the old
  // bound `i < n` left the x^n coefficient wrong when n was a power of two).
  Poly sqrt(int n) const {
    assert(freq(0) == 1);
    Poly f = pre(n + 1);
    Poly g({1});
    for (int i = 1; i <= n; i *= 2) {
      g = (g + f.pre(2 * i) * g.inv(2 * i)) / 2;
    }
    return g.pre(n + 1);
  }
  // Square root whose constant term need not be 1. Returns (false, {}) when
  // no square root exists; otherwise (true, s) truncated like sqrt().
  // The modulus is recovered as D(-1).v + 1, so D is assumed ModInt-like with
  // a prime modulus (previously hard-coded to 998244353).
  pair<bool, Poly> sqrt_arb(int n) const {
    if (size() == 0) {
      return {true, Poly()};
    }
    int c = 0;
    while (c * 2 < size() && !freq(c * 2)) {
      if (freq(c * 2 + 1)) {
        return {false, Poly()};
      }
      c += 1;
    }
    const ll mod = static_cast<ll>(D(-1).v) + 1;
    // Keep the raw result: -1 would otherwise be indistinguishable from the
    // valid residue mod - 1 once converted to D.
    const ll root = Tonelli_Shanks(static_cast<ll>(freq(c * 2).v), mod);
    if (root == -1) {
      return {false, Poly()};
    }
    if (n <= c) {
      return {true, Poly()};
    }
    const D x = root;
    Poly P = (*this) >> c * 2;
    P /= x * x;
    P = P.sqrt(n - c);
    P <<= c;
    P *= x;
    return {true, P};
  }
  // First n coefficients of (*this)^k for k >= 0.
  Poly power(ll k, int n) const {
    if (!k) {
      return Poly({D(1)});
    }
    if (n <= 0 || !size()) {
      return Poly();
    }
    int c = 0;
    while (c < size() && !freq(c)) {
      c += 1;
    }
    // Lowest term of the result is x^(c*k); nothing survives if c*k >= n.
    if (c > (n - 1) / k) {
      return Poly();
    }
    const D ic = freq(c).inv();
    const D pc = freq(c).pow(k);
    const int l = int(n - c * k);
    // g = f / (f_c x^c) has constant term 1, so g^k = exp(k log g).
    Poly g = (pre(l + c) * ic) >> c;
    Poly gk = (g.log(l) * D(k)).exp(l);
    return ((gk * pc) << int(c * k)).pre(n);
  }
  // (*this)^n modulo `mod`. If only the first N terms are needed, replacing
  // `% mod` with `.pre(N)` is faster (though still slow).
  Poly pow_mod(ll n, const Poly& mod) const {
    Poly x = *this, r = {{1}};
    while (n) {
      if (n & 1) r = r * x % mod;
      x = x * x % mod;
      n >>= 1;
    }
    return r;
  }
  friend ostream& operator<<(ostream& os, const Poly& p) {
    if (p.size() == 0) return os << "0";
    for (auto i = 0; i < p.size(); i++) {
      if (p.v[i]) {
        os << p.v[i] << "x^" << i;
        if (i != p.size() - 1) os << "+";
      }
    }
    return os;
  }
};

// Multipoint evaluation of a polynomial over a fixed set of points using a
// subproduct tree. A node covers points [off, off + sz) of the input and
// stores mul = prod (X - x) over them; nodes with sz <= 100 are leaves that
// evaluate directly. Children are owned through unique_ptr, so the whole tree
// is released with the root (the previous raw `new` was never deleted).
template <class Mint> struct MultiEval {
  using NP = MultiEval*;  // kept for source compatibility
  std::unique_ptr<MultiEval> l, r;  // null at leaves
  vector<Mint> que;                  // points held by a leaf
  int sz;
  Poly<Mint> mul;
  MultiEval(const vector<Mint>& _que, int off, int _sz) : sz(_sz) {
    if (sz <= 100) {
      que = {_que.begin() + off, _que.begin() + off + sz};
      mul = {{1}};
      for (auto x : que) mul *= {{-x, 1}};
      return;
    }
    l = std::make_unique<MultiEval>(_que, off, sz / 2);
    r = std::make_unique<MultiEval>(_que, off + sz / 2, sz - sz / 2);
    mul = l->mul * r->mul;
  }
  MultiEval(const vector<Mint>& _que) : MultiEval(_que, 0, int(_que.size())) {}
  // Appends pol(x) for every point x under this node, in input order.
  void query(const Poly<Mint>& _pol, vector<Mint>& res) const {
    if (sz <= 100) {
      for (auto x : que) {
        // Direct evaluation: accumulate coefficient * x^i.
        Mint sm = 0, base = 1;
        for (int i = 0; i < _pol.size(); i++) {
          sm += base * _pol.freq(i);
          base *= x;
        }
        res.push_back(sm);
      }
      return;
    }
    // Reducing modulo this node's product preserves the values at its points.
    auto pol = _pol % mul;
    l->query(pol, res);
    r->query(pol, res);
  }
  vector<Mint> query(const Poly<Mint>& pol) const {
    vector<Mint> res;
    query(pol, res);
    return res;
  }
};

// Berlekamp–Massey: finds a shortest linear recurrence for the sequence s.
// Returns the connection coefficients c (as a Poly) such that for each window
// the discrepancy sum_i c[i] * s[ed - l + i] is zero. c.back() is always -1:
// it starts as -1 and every update subtracts b's trailing (freshly pushed) 0
// from it. Hence taking rev() and multiplying by -1 yields the recurrence's
// denominator in the form 1 + ... .
template <class Mint> Poly<Mint> berlekamp_massey(const vector<Mint>& s) {
  int n = int(s.size());
  // b: previous connection polynomial, c: current one; y: last nonzero discrepancy.
  vector<Mint> b = {Mint(-1)}, c = {Mint(-1)};
  Mint y = Mint(1);
  for (int ed = 1; ed <= n; ed++) {
    int l = int(c.size()), m = int(b.size());
    // Discrepancy of the current recurrence at the window ending at s[ed - 1].
    Mint x = 0;
    for (int i = 0; i < l; i++) {
      x += c[i] * s[ed - l + i];
    }
    // Shift b by one position for the next step.
    b.push_back(0);
    m++;
    if (!x) continue;
    Mint freq = x / y;
    if (l < m) {
      // use b: the recurrence must grow; c becomes c - freq * b (aligned at the
      // end) and the old c is kept as the new b.
      auto tmp = c;
      c.insert(begin(c), m - l, Mint(0));
      for (int i = 0; i < m; i++) {
        c[m - 1 - i] -= freq * b[m - 1 - i];
      }
      b = tmp;
      y = x;
    } else {
      // use c: correct c in place without changing its length.
      for (int i = 0; i < m; i++) {
        c[l - 1 - i] -= freq * b[m - 1 - i];
      }
    }
  }
  return c;
}

// Returns the coefficient of x^N in the power series n(x) / d(x) (N >= 0).
// Requires d.freq(0) != 0. Each step multiplies numerator and denominator by
// d(-x). Because d(x)d(-x) has only even powers, [x^N] n/d reduces to a
// half-size problem on the coefficients of n(x)d(-x) with N's parity.
template <class D>
D Bostan_Mori(Poly<D> n, Poly<D> d, ll N) {
  // Keeps the coefficients of x^0, x^2, x^4, ... re-indexed as x^0, x^1, ...
  auto take_even = [](Poly<D>& p) {
    for (int i = 0; i < p.size(); i += 2) p.v[i / 2] = p.v[i];
    p = p.pre((p.size() + 1) / 2);
  };
  while (N) {
    // dd = d(-x)
    Poly<D> dd(d.v);
    for (int i = 1; i < dd.size(); i += 2) dd.v[i] = -dd.v[i];
    n *= dd;
    // For odd N the wanted coefficients sit at odd indices; shift them to even.
    if (N % 2) n >>= 1;
    take_even(n);
    d *= dd;
    take_even(d);
    N >>= 1;
  }
  // d's constant term is now d(0)^(2^steps), which equals 1 only when the
  // original d(0) == 1, so divide explicitly instead of returning n(0) alone.
  if (!n.size()) return D(0);
  return n.freq(0) / d.freq(0);
}

// N-th term (0-indexed) of the sequence whose prefix is A, assuming A is long
// enough for berlekamp_massey to recover its minimal linear recurrence.
// The generating function is written as numer(x) / denom(x) and evaluated
// with Bostan_Mori.
template<class D>
D BMBM(vector<D> A, ll N) {
  // denom(x): reversed connection polynomial negated so that denom(0) == 1.
  Poly<D> denom = berlekamp_massey(A).rev();
  denom *= D(-1);
  // numer(x): (denom * A)(x) truncated below deg(denom).
  const Poly<D> prefix(A);
  const Poly<D> numer = (denom * prefix).pre(denom.size() - 1);
  return Bostan_Mori(numer, denom, N);
}


// Second input value (read in main as `cin >> N >> K`).
int K;
// fact[i] = i! and fact_inve[i] = 1 / i! (as Mint), filled in main for i <= N + K.
vector<Mint> fact, fact_inve;

// Memoized logarithms of truncated exponential series, keyed by the cutoff C.
unordered_map<int, Poly<Mint>> memo_log;
// Returns log(sum_{x=0}^{min(C, K)} t^x / x!) as a series with K + 1 terms,
// computing it at most once per C. Reads the globals K and fact_inve.
Poly<Mint> memoized_log(const int C) {
    const auto cached = memo_log.find(C);
    if (cached != memo_log.end()) return cached->second;

    // Coefficients 1/x! up to the cutoff, zero beyond it.
    vector<Mint> series(K + 1, 0);
    const int last = min(C, K);
    for (int x = 0; x <= last; ++x) series[x] = fact_inve[x];

    Poly<Mint> logged = Poly<Mint>(series).log(K + 1);
    memo_log.emplace(C, logged);
    return logged;
}

// Computes C! * [t^C] prod_i (sum_{x=0}^{b_i} t^x / x!), where
// b_i = min(A[i] / S, C) and S = K / C. Presumably this is the Burnside count
// of colourings fixed by a rotation with C cycles of length S (colour i can
// fill at most A[i] / S whole cycles) — confirm against the problem statement.
// The product is evaluated as exp(sum_i log(...)), grouping equal b_i so each
// distinct logarithm is obtained once through memoized_log.
// Requires 1 <= C and C dividing K. A is only read (taken by const reference
// to avoid copying it on every call).
Mint solve(int C, const vector<int>& A, int N) {
    const int S = K / C;
    // cnt[b] = number of colours that can fill up to b cycles (b in [1, C]).
    vector<int> cnt(C + 1, 0);
    for (int i = 0; i < N; ++i) {
        const int b = min(A[i] / S, C);
        if (b > 0) cnt[b]++;  // colours with b == 0 contribute a factor of 1
    }
    vector<Mint> poly_log(C + 1, 0);
    for (int b = 1; b <= C; ++b) {
        if (!cnt[b]) continue;
        const Poly<Mint> log_poly = memoized_log(b);
        for (int x = 0; x <= C; ++x) {
            poly_log[x] += log_poly.freq(x) * cnt[b];
        }
    }
    // Every log has a zero constant term, as exp() requires.
    Poly<Mint> poly(poly_log);
    poly = poly.exp(C + 1);
    return poly.freq(C) * fact[C];
}

int main() {
    int N;
    cin >> N >> K;

    vector<int> A(N);
    for(int i=0;i<N;i++){
      cin>>A[i];
    }

    fact.resize(N + K + 1, 1);
    fact_inve.resize(N + K + 1);

    for (int i = 1; i <= N + K; ++i) {
        fact[i] = fact[i - 1] * Mint(i);
    }
    for (int i = 0; i <= N + K; ++i) {
        fact_inve[i] = Mint(1)/fact[i];
    }

    Mint ans = 0;
    for (int d = 0; d < K; ++d) {
        ans = ans + solve(gcd(K, d), A, N);
    }

    ans = ans * modpow(K, MOD - 2, MOD);
    cout << ans << endl;
    return 0;
}
0