結果

問題 No.1145 Sums of Powers
ユーザー 👑 hos.lyric
提出日時 2020-07-31 22:52:41
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 6,942 bytes
コンパイル時間 4,917 ms
コンパイル使用メモリ 271,248 KB
実行使用メモリ 7,392 KB
最終ジャッジ日時 2024-07-06 20:20:16
合計ジャッジ時間 6,407 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 WA -
testcase_01 WA -
testcase_02 WA -
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <complex>
#include <deque>
#include <functional>
#include <iostream>
#include <map>
#include <numeric>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace std;

using Int = long long;  // 64-bit integer shorthand

// Print a pair as "(first, second)".
template <class T1, class T2>
ostream &operator<<(ostream &os, const pair<T1, T2> &a) {
  os << "(" << a.first << ", " << a.second << ")";
  return os;
}

// Debug dump of the half-open iterator range [a, b) to stderr, space separated.
template <class T> void pv(T a, T b) {
  while (a != b) {
    cerr << *a << " ";
    ++a;
  }
  cerr << endl;
}

// Lower t to f when f is strictly smaller; reports whether t changed.
template <class T> bool chmin(T &t, const T &f) {
  if (!(t > f)) return false;
  t = f;
  return true;
}

// Raise t to f when f is strictly larger; reports whether t changed.
template <class T> bool chmax(T &t, const T &f) {
  if (!(t < f)) return false;
  t = f;
  return true;
}


// Integer modulo the compile-time modulus M_. inv() asserts gcd(x, M) == 1, so
// every non-zero value is invertible only when M_ is prime. Values are kept
// normalized in [0, M) (addition relies on 2 * M fitting in int).
template<int M_> struct ModInt {
  static constexpr int M = M_;
  int x;  // canonical representative, 0 <= x < M
  constexpr ModInt() : x(0) {}
  // Implicit on purpose so plain integers mix freely with ModInt; negative
  // inputs are mapped into [0, M).
  constexpr ModInt(long long x_) : x(x_ % M) { if (x < 0) x += M; }
  ModInt &operator+=(const ModInt &a) { x += a.x; if (x >= M) x -= M; return *this; }
  ModInt &operator-=(const ModInt &a) { x -= a.x; if (x < 0) x += M; return *this; }
  ModInt &operator*=(const ModInt &a) { x = static_cast<int>((static_cast<long long>(x) * a.x) % M); return *this; }
  ModInt &operator/=(const ModInt &a) { return (*this *= a.inv()); }
  ModInt operator+(const ModInt &a) const { return (ModInt(*this) += a); }
  ModInt operator-(const ModInt &a) const { return (ModInt(*this) -= a); }
  ModInt operator*(const ModInt &a) const { return (ModInt(*this) *= a); }
  ModInt operator/(const ModInt &a) const { return (ModInt(*this) /= a); }
  ModInt operator-() const { return ModInt(-x); }
  // x^e by binary exponentiation. A negative e uses the inverse (x must then
  // be invertible); previously a negative e never terminated because the
  // arithmetic right shift keeps e at -1 forever.
  ModInt pow(long long e) const {
    if (e < 0) {
      // x^e = (x^-1)^(-(e + 1)) * x^-1; negating e + 1 instead of e avoids
      // signed overflow when e == LLONG_MIN.
      const ModInt ix = inv();
      return ix.pow(-(e + 1)) * ix;
    }
    ModInt x2 = x, xe = 1;
    for (; e; e >>= 1) {
      if (e & 1) xe *= x2;
      x2 *= x2;
    }
    return xe;
  }
  // Multiplicative inverse by the extended Euclidean algorithm on (x, M).
  // Invariants: a == y * x and b == z * x (mod M). When one side reaches 0
  // the other is gcd(x, M), which must be +-1 (asserted), and multiplying its
  // coefficient by that sign gives the inverse.
  ModInt inv() const {
    int a = x, b = M, y = 1, z = 0, t;
    for (; ; ) {
      t = a / b; a -= t * b;
      if (a == 0) {
        assert(b == 1 || b == -1);
        return ModInt(b * z);
      }
      y -= t * z;
      t = b / a; b -= t * a;
      if (b == 0) {
        assert(a == 1 || a == -1);
        return ModInt(a * y);
      }
      z -= t * y;
    }
  }
  friend ModInt operator+(long long a, const ModInt &b) { return (ModInt(a) += b); }
  friend ModInt operator-(long long a, const ModInt &b) { return (ModInt(a) -= b); }
  friend ModInt operator*(long long a, const ModInt &b) { return (ModInt(a) *= b); }
  friend std::ostream &operator<<(std::ostream &os, const ModInt &a) { return os << a.x; }
};

constexpr int MO = 998244353;
using Mint = ModInt<MO>;


// M: prime, G: primitive root
template <int M, int G, int K> struct Fft {
  // 1, 1/4, 1/8, 3/8, 1/16, 5/16, 3/16, 7/16, ...
  int gs[1 << (K - 1)];
  constexpr Fft() : gs() {
    static_assert(2 <= K && K <= 30, "Fft: 2 <= K <= 30 must hold");
    static_assert(!((M - 1) & ((1 << K) - 1)), "Fft: 2^K | M - 1 must hold");
    gs[0] = 1;
    long long g2 = G, gg = 1;
    for (int e = (M - 1) >> K; e; e >>= 1) {
      if (e & 1) gg = (gg * g2) % M;
      g2 = (g2 * g2) % M;
    }
    gs[1 << (K - 2)] = gg;
    for (int l = 1 << (K - 2); l >= 2; l >>= 1) {
      gs[l >> 1] = (static_cast<long long>(gs[l]) * gs[l]) % M;
    }
    assert((static_cast<long long>(gs[1]) * gs[1]) % M == M - 1);
    for (int l = 2; l <= 1 << (K - 2); l <<= 1) {
      for (int i = 1; i < l; ++i) {
        gs[l + i] = (static_cast<long long>(gs[l]) * gs[i]) % M;
      }
    }
  }
  void fft(vector<int> &xs) const {
    const int n = xs.size();
    assert(!(n & (n - 1)) && n <= 1 << K);
    for (int l = n; l >>= 1; ) {
      for (int i = 0; i < (n >> 1) / l; ++i) {
        const long long g = gs[i];
        for (int j = (i << 1) * l; j < (i << 1 | 1) * l; ++j) {
          const int t = (g * xs[j + l]) % M;
          if ((xs[j + l] = xs[j] - t) < 0) xs[j + l] += M;
          if ((xs[j] += t) >= M) xs[j] -= M;
        }
      }
    }
  }
  void invFft(vector<int> &xs) const {
    const int n = xs.size();
    assert(!(n & (n - 1)) && n <= 1 << K);
    for (int l = 1; l < n; l <<= 1) {
      std::reverse(xs.begin() + l, xs.begin() + (l << 1));
    }
    for (int l = 1; l < n; l <<= 1) {
      for (int i = 0; i < (n >> 1) / l; ++i) {
        const long long g = gs[i];
        for (int j = (i << 1) * l; j < (i << 1 | 1) * l; ++j) {
          int t = (g * (xs[j] - xs[j + l])) % M;
          if (t < 0) t += M;
          if ((xs[j] += xs[j + l]) >= M) xs[j] -= M;
          xs[j + l] = t;
        }
      }
    }
  }
  template<class T>
  vector<T> convolute(const vector<T> &as, const vector<T> &bs) const {
    const int na = as.size(), nb = bs.size();
    int n, invN = 1;
    for (n = 1; n < na + nb - 1; n <<= 1) {
      invN = ((invN & 1) ? (invN + M) : invN) >> 1;
    }
    vector<int> xs(n, 0), ys(n, 0);
    for (int i = 0; i < na; ++i) if ((xs[i] = as[i] % M) < 0) xs[i] += M;
    for (int i = 0; i < nb; ++i) if ((ys[i] = bs[i] % M) < 0) ys[i] += M;
    fft(xs);
    fft(ys);
    for (int i = 0; i < n; ++i) {
      xs[i] = (((static_cast<long long>(xs[i]) * ys[i]) % M) * invN) % M;
    }
    invFft(xs);
    xs.resize(na + nb - 1);
    return xs;
  }
  vector<Mint> convolute(const vector<Mint> &as, const vector<Mint> &bs) const {
    const int na = as.size(), nb = bs.size();
    int n, invN = 1;
    for (n = 1; n < na + nb - 1; n <<= 1) {
      invN = ((invN & 1) ? (invN + M) : invN) >> 1;
    }
    vector<int> xs(n, 0), ys(n, 0);
    for (int i = 0; i < na; ++i) xs[i] = as[i].x;
    for (int i = 0; i < nb; ++i) ys[i] = bs[i].x;
    fft(xs);
    fft(ys);
    for (int i = 0; i < n; ++i) {
      xs[i] = (((static_cast<long long>(xs[i]) * ys[i]) % M) * invN) % M;
    }
    invFft(xs);
    vector<Mint> ret(na + nb - 1);
    for (int i = 0; i < na + nb - 1; ++i) ret[i].x = xs[i];
    return ret;
  }
};

const Fft<998244353, 3, 20> FFT;


int N, M;
vector<int> A;

// Divide and conquer over the half-open range A[l, r). Returns
//   first  = prod_{l<=i<r} (1 - A[i] x)
//   second = sum_{l<=i<r} prod_{l<=j<r, j!=i} (1 - A[j] x)
// so that second / first = sum_i 1 / (1 - A[i] x) = sum_k (sum_i A[i]^k) x^k.
pair<vector<Mint>, vector<Mint>> solve(int l, int r) {
  if (r <= l) {
    // Empty range (only reachable when N == 0): empty product is 1, empty sum
    // is the zero polynomial. Previously this recursed on solve(l, l) forever.
    return make_pair(vector<Mint>{1}, vector<Mint>{});
  }
  if (r - l == 1) {
    return make_pair(vector<Mint>{1, -A[l]}, vector<Mint>{1});
  }
  const int mid = (l + r) / 2;
  const auto resL = solve(l, mid);
  const auto resR = solve(mid, r);
  const auto ret0 = FFT.convolute(resL.first, resR.first);
  // Product rule: second(L+R) = second(L) * first(R) + first(L) * second(R).
  // Both products have exactly r - l coefficients.
  auto ret1 = FFT.convolute(resL.second, resR.first);
  const auto tmp = FFT.convolute(resL.first, resR.second);
  for (int i = 0; i < r - l; ++i) {
    ret1[i] += tmp[i];
  }
  return make_pair(ret0, ret1);
}

int main() {
  for (; ~scanf("%d%d", &N, &M); ) {
    A.resize(N);
    for (int i = 0; i < N; ++i) {
      scanf("%d", &A[i]);
    }
    
    const auto res = solve(0, N);
// cerr<<"res[0] = ";pv(res.first.begin(),res.first.end());
// cerr<<"res[1] = ";pv(res.second.begin(),res.second.end());
    // const auto ans = FFT.convolute(inv(res.first, M + 1), res.second);
    
  }
  return 0;
}
0