結果

問題 No.931 Multiplicative Convolution
ユーザー firiexpfiriexp
提出日時 2019-11-23 00:50:15
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 86 ms / 2,000 ms
コード長 7,582 bytes
コンパイル時間 1,309 ms
コンパイル使用メモリ 121,376 KB
実行使用メモリ 23,856 KB
最終ジャッジ日時 2024-04-27 15:49:30
合計ジャッジ時間 3,744 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
6,816 KB
testcase_01 AC 2 ms
6,940 KB
testcase_02 AC 2 ms
6,940 KB
testcase_03 AC 2 ms
6,940 KB
testcase_04 AC 1 ms
6,940 KB
testcase_05 AC 2 ms
6,940 KB
testcase_06 AC 2 ms
6,944 KB
testcase_07 AC 10 ms
6,940 KB
testcase_08 AC 79 ms
23,856 KB
testcase_09 AC 66 ms
23,460 KB
testcase_10 AC 77 ms
23,396 KB
testcase_11 AC 70 ms
23,308 KB
testcase_12 AC 47 ms
14,332 KB
testcase_13 AC 86 ms
23,460 KB
testcase_14 AC 75 ms
23,760 KB
testcase_15 AC 76 ms
23,720 KB
testcase_16 AC 78 ms
23,600 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <limits>
#include <iostream>
#include <algorithm>
#include <iomanip>
#include <map>
#include <set>
#include <queue>
#include <stack>
#include <numeric>
#include <bitset>
#include <cmath>

static const int MOD = 998244353;
using ll = long long;
using namespace std;

template<class T> constexpr T INF = ::numeric_limits<T>::max()/32*15+208;

#include <chrono>
class xor_shift {
    uint32_t x, y, z, w;
public:
    xor_shift() : x(static_cast<uint32_t>((chrono::system_clock::now().time_since_epoch().count())&((1LL << 32)-1))),
    y(1068246329), z(321908594), w(1234567890) {};

    uint32_t urand(){
        uint32_t t;
        t = x ^ (x << 11);
        x = y; y = z; z = w;
        w = (w ^ (w >> 19)) ^ (t ^ (t >> 8));
        return w;
    };

    int rand(int n){
        if(n < 0) return -rand(-n);
        uint32_t t = numeric_limits<uint32_t>::max()/(n+1)*(n+1);
        uint32_t e = urand();
        while(e >= t) e = urand();
        return static_cast<int>(e%(n+1));
    }

    int rand(int a, int b){
        if(a > b) swap(a, b);
        return a+rand(b-a);
    }
};

constexpr int ntt_mod = 998244353, ntt_root = 3;
// 1012924417 -> 5, 924844033 -> 5
// 998244353  -> 3, 897581057 -> 3
// 645922817  -> 3;

template<int M, int proot>
class NTT {
    vector<vector<int>> rts, rrts;
public:
    NTT() = default;

    inline int add(int a,int b){ a += b; if(a >= M) a -= M; return a; }

    inline int mul(int a,int b){ return 1LL * a * b % M; }

    inline int pow(int a,int b){
        int res = 1;
        while(b){
            if(b&1) res = mul(res, a);
            a = mul(a, a);
            b >>= 1;
        }
        return res;
    }

    inline int extgcd(int a, int b, int &x ,int &y){
        for (int u = y = 1, v = x = 0; a; ) {
            ll q = b/a;
            swap(x -= q*u, u);
            swap(y -= q*v, v);
            swap(b -= q*a, a);
        }
        return b;
    }
    inline int inv(int x){
        int s, t;
        extgcd(x, M, s, t);
        return (M+s)%M;
    }

    void ensure_base(int N) {
        if(rts.size() >= N) return;
        rts.resize(N), rrts.resize(N);
        for(int i = 1; i < N; i <<= 1) {
            if(!rts[i].empty()) continue;
            int w = pow(proot, (M - 1) / (i << 1));
            int rw = inv(w);
            rts[i].resize(i), rrts[i].resize(i);
            rts[i][0] = 1, rrts[i][0] = 1;
            for(int k = 1; k < i; k++) {
                rts[i][k] = mul(rts[i][k - 1],w);
                rrts[i][k] = mul(rrts[i][k - 1],rw);
            }
        }
    }
    void ntt(vector<int> &a, int sign){
        int n = a.size();
        ensure_base(n);

        for (int i = 0, j = 1; j < n-1; ++j) {
            for (int k = n >> 1; k > (i ^= k); k >>= 1);
            if(j < i) swap(a[i], a[j]);
        }
        for (int i = 1; i < n; i <<= 1) {
            for (int j = 0; j < n; j += i * 2) {
                for (int k = 0; k < i; ++k) {
                    int y = mul(a[j+k+i], (sign ? rrts[i][k] : rts[i][k]));
                    a[j+k+i] = add(a[j+k], M-y), a[j+k] = add(a[j+k], y) ;
                }
            }
        }
        if(sign) {
            int temp = inv(n);
            for (int i = 0; i < n; ++i) a[i] = mul(a[i],temp);
        }
    }
};



NTT<ntt_mod, ntt_root> ntt;

constexpr int M = ntt_mod;

struct poly {
    vector<int> v;
    poly() = default;
    explicit poly(int n) : v(n) {};
    explicit poly(vector<int> vv) : v(std::move(vv)) {};
    int size() const {return (int)v.size(); }
    poly cut(int len){
        if(len < v.size()) v.resize(static_cast<unsigned long>(len));
        return *this;
    }
    inline int& operator[] (int i) {return v[i]; }
    poly operator+(const poly &a) const { return poly(*this) += a; }
    poly operator-(const poly &a) const { return poly(*this) -= a; }
    poly operator*(const poly &a) const { return poly(*this) *= a; }

    poly inv() const {
        int n = size();
        vector<int> rr(1, ntt.inv(this->v[0]));
        poly r(rr);
        for (int k = 2; k <= n; k <<= 1) {
            vector<int> u(k);
            for (int i = 0; i < k; ++i) {
                u[i] = this->v[i];
            }
            poly ff(u);
            poly nr = (r*r);
            nr = nr*ff;
            nr.cut(k);
            for (int i = 0; i < k/2; ++i) {
                nr[i] = (2*r[i]-nr[i]+M)%M;
                nr[i+k/2] = (M-nr[i+k/2])%M;
            }
            r = nr;
        }
        r.v.resize(n);
        return r;
    }

    poly& operator+=(const poly &a) {
        this->v.resize(max(size(), a.size()));
        for (int i = 0; i < a.size(); ++i) {
            (this->v[i] += a.v[i]);
            if(this->v[i] > ntt_mod) this->v[i] -= M;
        }
        return *this;
    }
    poly& operator-=(const poly &a) {
        this->v.resize(max(size(), a.size()));
        for (int i = 0; i < a.size(); ++i) {
            (this->v[i] += M-a.v[i]);
            if(this->v[i] > M) this->v[i] -= M;
        }
        return *this;
    }

    poly& operator*=(poly a) {
        int N = size()+a.size()-1;
        int sz = 1;
        while(sz < N) sz <<= 1;
        ntt.ensure_base(sz);
        this->v.resize(sz); a.v.resize(sz);
        ntt.ntt(this->v, 0); ntt.ntt(a.v, 0);
        for(int i = 0; i < sz; ++i) this->v[i] = ntt.mul(this->v[i], a.v[i]);
        ntt.ntt(this->v, 1);
        this->cut(N);
        return *this;
    }

    poly& operator/=(const poly &a){
        return (*this *= a.inv());
    }
};

template <class T>
T pow_ (T x, T n, T M){
    uint64_t u = 1, xx = x;
    while (n > 0){
        if (n&1) u = u * xx % M;
        xx = xx * xx % M;
        n >>= 1;
    }
    return static_cast<T>(u);
};

vector<int> get_prime(int n){
    if(n <= 1) return vector<int>();
    vector<bool> is_prime(n+1, true);
    vector<int> prime;
    is_prime[0] = is_prime[1] = 0;
    for (int i = 2; i <= n; ++i) {
        if(is_prime[i]) prime.emplace_back(i);
        for (auto &&j : prime){
            if(i*j > n) break;
            is_prime[i*j] = false;
            if(i % j == 0) break;
        }
    }
    return prime;
}
const auto primes = get_prime(1000);

template<class T>
vector<T> prime_factor(T n){
    vector<T> res;
    for (auto &&i : primes) {
        while (n % i == 0){
            res.emplace_back(i);
            n /= i;
        }
    }
    if(n != 1) res.emplace_back(n);
    sort(res.begin(), res.end());
    res.erase(unique(res.begin(), res.end()), res.end());
    return res;
}


int main() {
    int p;
    cin >> p;
    if(p == 2){
        ll x, y;
        cin >> x >> y;
        cout << x*y%MOD << "\n";
        return 0;
    }
    xor_shift rd;
    int g = rd.rand(2, p-1);
    auto ps = prime_factor(p-1);
    while(true){
        int ok = 1;
        for (auto &&i : ps) {
            if(pow_(g, (p-1)/i, p) == 1){
                ok = false;
            }
        }
        if(ok) break;
        g = rd.rand(2, p-1);
    }
    vector<ll> gs(p*2+1, 1);
    for (int i = 1; i < 2*p; ++i) {
        gs[i] = gs[i-1]*g % p;
    }
    poly A(p-1), B(p-1);
    vector<int> a(p), b(p);
    for (int i = 1; i < p; ++i) scanf("%d", &a[i]);
    for (int i = 1; i < p; ++i) scanf("%d", &b[i]);
    for (int i = 0; i < p-1; ++i) {
        A[i] = a[gs[i]];
        B[i] = b[gs[i]];
    }
    auto C = A*B;
    vector<int> ans(p);
    for (int i = 0; i < C.size(); ++i) {
        ans[gs[i]] += C[i];
        if(ans[gs[i]] > MOD) ans[gs[i]] -= MOD;
    }
    for (int i = 0; i < p-1; ++i) {
        if(i) printf(" ");
        printf("%d", ans[i+1]);
    }
    puts("");
    return 0;
}
0