結果

問題 No.931 Multiplicative Convolution
ユーザー outlineoutline
提出日時 2021-04-02 19:10:45
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 71 ms / 2,000 ms
コード長 5,662 bytes
コンパイル時間 1,798 ms
コンパイル使用メモリ 144,640 KB
実行使用メモリ 9,360 KB
最終ジャッジ日時 2024-12-23 15:30:15
合計ジャッジ時間 4,306 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 2 ms
5,248 KB
testcase_02 AC 2 ms
5,248 KB
testcase_03 AC 2 ms
5,248 KB
testcase_04 AC 2 ms
5,248 KB
testcase_05 AC 2 ms
5,248 KB
testcase_06 AC 3 ms
5,248 KB
testcase_07 AC 9 ms
5,248 KB
testcase_08 AC 70 ms
9,316 KB
testcase_09 AC 55 ms
9,032 KB
testcase_10 AC 68 ms
9,008 KB
testcase_11 AC 58 ms
9,052 KB
testcase_12 AC 60 ms
8,576 KB
testcase_13 AC 67 ms
9,284 KB
testcase_14 AC 69 ms
9,236 KB
testcase_15 AC 71 ms
9,360 KB
testcase_16 AC 69 ms
9,152 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
#include <queue>
#include <string>
#include <map>
#include <set>
#include <stack>
#include <tuple>
#include <deque>
#include <array>
#include <numeric>
#include <bitset>
#include <iomanip>
#include <cassert>
#include <chrono>
#include <random>
#include <limits>
#include <iterator>
#include <functional>
#include <sstream>
#include <fstream>
#include <complex>
#include <cstring>
#include <unordered_map>
#include <unordered_set>
using namespace std;

using ll = long long;
constexpr int INF = 1001001001;
// constexpr int mod = 1000000007;
constexpr int mod = 998244353;

template<class T>
inline bool chmax(T& x, T y){
    if(x < y){
        x = y;
        return true;
    }
    return false;
}
template<class T>
inline bool chmin(T& x, T y){
    if(x > y){
        x = y;
        return true;
    }
    return false;
}

// https://drken1215.hatenablog.com/entry/2020/11/04/155800
// https://37zigen.com/primitive-root/

long long modpow(long long x, long long n, int m){
    long long res = 1;
    while(n > 0){
        if(n & 1)   (res *= x) %= m;
        (x *= x) %= m;
        n >>= 1;
    }
    return res;
}

int calc_primitive_root(int p){
    if(p == 2)  return 1;
    if(p == 167772161)  return 3;
    if(p == 469762049)  return 3;
    if(p == 754974721)  return 11;
    if(p == 998244353)  return 3;

    // p-1 の素因数分解
    int divs[20] = {};
    divs[0] = 2;
    int cnt = 1;
    long long x = (p - 1) / 2;
    while(x % 2 == 0)   x /= 2;
    for(long long i = 3; i * i <= x; i += 2){
        if(x % i == 0){
            divs[cnt++] = i;
            while(x % i == 0)   x /= i;
        }
    }
    if(x > 1)   divs[cnt++] = x;
    
    // 原始根であるかの判定のために root^((p-1)/d) != 1 (mod p) を確かめる
    for(int root = 2;; ++root){
        bool ok = true;
        for(int i = 0; i < cnt; ++i){
            if(modpow(root, (p - 1) / divs[i], p) == 1){
                ok = false;
                break;
            }
        }
        if(ok)  return root;
    }
}

template<int mod>
struct NumberTheoreticTransform{
    vector<int> rev, rts;
    int base, max_base, root;

    NumberTheoreticTransform() : base(1), rev{0, 1}, rts{0, 1} {
        assert(mod >= 3 && mod % 2 == 1);
        auto tmp = mod - 1;
        max_base = 0;
        while(tmp % 2 == 0) tmp >>= 1, ++max_base;
        root = 2;
        while(mod_pow(root, (mod - 1) >> 1) == 1)   ++root;
        assert(mod_pow(root, mod - 1) == 1);
        root = mod_pow(root, (mod - 1) >> max_base);
    }

    inline int mod_pow(int x, int n){
        int ret = 1;
        while(n > 0){
            if(n & 1)   ret = mul(ret, x);
            x = mul(x, x);
            n >>= 1;
        }
        return ret;
    }

    inline int inverse(int x){
        return mod_pow(x, mod - 2);
    }

    inline unsigned add(unsigned x, unsigned y){
        x += y;
        if(x >= mod)    x -= mod;
        return x;
    }

    inline unsigned mul(unsigned a, unsigned b){
        return 1ull * a * b % (unsigned long long)mod;
    }

    void ensure_base(int nbase){
        if(nbase <= base)   return;
        rev.resize(1 << nbase);
        rts.resize(1 << nbase);
        for(int i = 0; i < (1 << nbase); ++i){
            rev[i] = (rev[i >> 1] >> 1) + ((i & 1) << (nbase - 1));
        }
        assert(nbase <= max_base);
        while(base < nbase){
            int z = mod_pow(root, 1 << (max_base - 1 - base));
            for(int i = 1 << (base - 1); i < (1 << base); ++i){
                rts[i << 1] = rts[i];
                rts[(i << 1) + 1] = mul(rts[i], z);
            }
            ++base;
        }
    }

    void ntt(vector<int> &a){
        const int n = (int)a.size();
        assert((n & (n - 1)) == 0);
        int zeros = __builtin_ctz(n);
        ensure_base(zeros);
        int shift = base - zeros;
        for(int i = 0; i < n; ++i){
            if(i < (rev[i] >> shift)){
                swap(a[i], a[rev[i] >> shift]);
            }
        }
        for(int k = 1; k < n; k <<= 1){
            for(int i = 0; i < n; i += 2 * k){
                for(int j = 0; j < k; ++j){
                    int z = mul(a[i + j + k], rts[j + k]);
                    a[i + j + k] = add(a[i + j], mod - z);
                    a[i + j] = add(a[i + j], z);
                }
            }
        }
    }

    vector<int> multiply(vector<int> a, vector<int> b){
        int need = a.size() + b.size() - 1;
        int nbase = 1;
        while((1 << nbase) < need)  ++nbase;
        ensure_base(nbase);
        int sz = 1 << nbase;
        a.resize(sz, 0);
        b.resize(sz, 0);
        ntt(a);
        ntt(b);
        int inv_sz = inverse(sz);
        for(int i = 0; i < sz; ++i){
            a[i] = mul(a[i], mul(b[i], inv_sz));
        }
        reverse(a.begin() + 1, a.end());
        ntt(a);
        a.resize(need);
        return a;
    }
};

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int P, x;
    cin >> P;

    int R = calc_primitive_root(P);
    vector<int> A(P), B(P), __log(P), powr(P, 1);
    for(int i = 1; i < P - 1; ++i){
        powr[i] = powr[i - 1] * R % P;
        __log[powr[i]] = i;
    }

    for(int i = 1; i < P; ++i)  cin >> A[__log[i]];
    for(int i = 1; i < P; ++i)  cin >> B[__log[i]];

    NumberTheoreticTransform<mod> ntt;
    auto C = ntt.multiply(A, B);

    vector<int> ans(P);
    for(int i = 0; i < (int)C.size(); ++i){
        int j = i % (P - 1);
        (ans[powr[j]] += C[i]) %= mod;
    }

    for(int i = 1; i < P; ++i){
        cout << ans[i] << (i == P - 1 ? '\n' : ' ');
    }

    return 0;
}
0