結果

問題 No.931 Multiplicative Convolution
ユーザー risujirohrisujiroh
提出日時 2019-11-22 18:58:24
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
CE  
(最新)
AC  
(最初)
実行時間 -
コード長 3,752 bytes
コンパイル時間 1,369 ms
コンパイル使用メモリ 165,024 KB
最終ジャッジ日時 2024-04-27 02:59:34
合計ジャッジ時間 2,538 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)
コンパイルエラー時のメッセージ・ソースコードは、提出者また管理者しか表示できないようにしております。(リジャッジ後のコンパイルエラーは公開されます)
ただし、clay言語の場合は開発者のデバッグのため、公開されます。

コンパイルメッセージ
main.cpp:30:17: error: non-local lambda expression cannot have a capture-default
   30 | const Mint G = [&] {
      |                 ^

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;

constexpr int P = 998244353;
struct Mint {
  int v;
  Mint(long long a = 0) : v((a %= P) < 0 ? a + P : a) {}
  Mint& operator*=(Mint r) { v = (long long)v * r.v % P; return *this; }
  Mint& operator/=(Mint r) { return *this *= r.inv(); }
  Mint& operator+=(Mint r) { if ((v += r.v) >= P) v -= P; return *this; }
  Mint& operator-=(Mint r) { if ((v -= r.v) < 0) v += P; return *this; }
  friend Mint operator*(Mint l, Mint r) { return l *= r; }
  friend Mint operator/(Mint l, Mint r) { return l /= r; }
  friend Mint operator+(Mint l, Mint r) { return l += r; }
  friend Mint operator-(Mint l, Mint r) { return l -= r; }
  Mint pow(long long n) const {
    Mint res = 1, a = *this;
    while (n) {
      if (n & 1) {
        res *= a;
      }
      a *= a;
      n >>= 1;
    }
    return res;
  }
  Mint inv() const { return pow(P - 2); }
};

const Mint G = [&] {
  Mint x = 1;
  while (true) {
    if (x.pow((P - 1) / 2).v != 1) {
      return x;
    }
    ++x.v;
  }
}();
void ntt(vector<Mint>& a, bool inv = false) {
  int n = a.size();
  for (int i = 1, j = 0; i < n; ++i) {
    int w = n >> 1;
    while (j >= w) {
      j -= w;
      w >>= 1;
    }
    j += w;
    if (i < j) {
      swap(a[i], a[j]);
    }
  }
  for (int w = 1; w < n; w *= 2) {
    Mint dt = G.pow((P - 1) / (2 * w));
    if (inv) {
      dt = dt.inv();
    }
    for (int s = 0; s < n; s += 2 * w) {
      Mint t = 1;
      for (int i = s, j = s + w; i < s + w; ++i, ++j) {
        Mint x = a[i], y = a[j] * t;
        a[i] = x + y;
        a[j] = x - y;
        t *= dt;
      }
    }
  }
}
vector<Mint> multiply(vector<Mint> a, vector<Mint> b) {
  int n = a.size(), m = b.size(), l = n + m - 1;
  int sz = 1 << __lg(2 * l - 1);
  a.resize(sz);
  b.resize(sz);
  ntt(a);
  ntt(b);
  for (int i = 0; i < sz; ++i) {
    a[i] *= b[i];
  }
  ntt(a, true);
  a.resize(l);
  auto inv = Mint(sz).inv();
  for (auto&& e : a) {
    e *= inv;
  }
  return a;
}

int main() {
  cin.tie(nullptr);
  ios::sync_with_stdio(false);
  string str;
  getline(cin, str);
  int p = stoi(str);
  assert(str == to_string(p));
  assert(2 <= p and p <= 99991);
  assert([&] {
    for (int i = 2; i * i <= p; ++i) {
      if (p % i == 0) {
        return false;
      }
    }
    return true;
  }());
  auto read_v = [&] {
    getline(cin, str);
    str += ' ';
    string s;
    vector<int> v;
    for (char c : str) {
      if (c == ' ') {
        v.push_back(stoi(s));
        string().swap(s);
      } else {
        s += c;
      }
    }
    assert((int)v.size() == p - 1);
    for (int e : v) {
      s += to_string(e);
      s += ' ';
      assert(0 <= e and e < P);
    }
    assert(str == s);
    return v;
  };
  vector<int> a = read_v(), b = read_v();
  assert(!getline(cin, str));

  const int g = [&] {
    int x = 1;
    while (true) {
      long long t = 1;
      int ord = 0;
      while (true) {
        t = t * x % p;
        ++ord;
        if (t == 1) {
          break;
        }
      }
      if (ord == p - 1) {
        return x;
      }
      ++x;
    }
  }();
  vector<Mint> na(p - 1), nb(p - 1);
  int gi = 1;
  for (int i = 0; i < p - 1; ++i) {
    na[i] = a[gi - 1];
    nb[i] = b[gi - 1];
    gi = (long long)gi * g % p;
  }
  auto nc = multiply(na, nb);
  for (int i = p - 1; i < (int)nc.size(); ++i) {
    nc[i - (p - 1)] += nc[i];
  }
  vector<int> c(p - 1);
  assert(gi == 1);
  for (int i = 0; i < p - 1; ++i) {
    c[gi - 1] = nc[i].v;
    gi = (long long)gi * g % p;
  }
  for (int i = 0; i < p - 1; ++i) {
    cout << c[i] << " \n"[i == p - 2];
  }
}
0