結果

問題 No.1145 Sums of Powers
ユーザー ace_amuroace_amuro
提出日時 2021-01-13 19:32:00
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 1,093 ms / 2,000 ms
コード長 6,862 bytes
コンパイル時間 814 ms
コンパイル使用メモリ 83,596 KB
実行使用メモリ 45,612 KB
最終ジャッジ日時 2024-05-02 04:30:31
合計ジャッジ時間 4,985 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 5 ms
9,088 KB
testcase_01 AC 5 ms
9,088 KB
testcase_02 AC 9 ms
9,344 KB
testcase_03 AC 1,072 ms
45,612 KB
testcase_04 AC 1,093 ms
45,496 KB
testcase_05 AC 1,069 ms
45,508 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>
#include<vector>
#include<algorithm>
using namespace std;
typedef long long LL;
typedef unsigned int uint;
LL ww[101];
LL* e = ww + 50;
const int MAXN = 1e5 + 5;
const LL P = 998244353;

LL qpow(LL a, LL k, LL p) {
    LL c = 1;
    a %= p;
    while (k) {
        if (k & 1) c = (c * a) % p;
        k >>= 1;
        a = (a * a) % p;
    }
    return c;
}

void primeroot(LL* e, const LL& P, const LL& g) {
    int s = 0; LL q = P - 1;             //`将$p$分解成$p=q\times2^s+1$的形式`
    while ((q & 1) == 0) {s++; q >>= 1;} //`计算$q$和$s$`
    LL w = qpow(g, q, P);                //`先计算$2^s$次原根$g^{\frac{p-1}{2^s}}$`
    LL invw = qpow(w, P - 2, P);         //`$2^s$次原根的逆元`
    for (int h = s; h >= 0; h--) {
        e[h] = w;                        //`$e[h]=g^{\frac{p-1}{2^h}}$`
        e[-h] = invw;                    //`$e[-h]=g^{-\frac{p-1}{2^h}}$`
        w = (w * w) % P;                 //`$2^{h}$次原根的平方等于$2^{h-1}$次原根`
        invw = (invw * invw) % P;
    }
}

void ntt(LL* f, const int& h, const int& type) {
    if (h == 0) return;                         //`递归出口`
    LL f0[1 << (h - 1)], f1[1 << (h - 1)];      //`新建两个大小为$2^{h-1}$的数组`
    int n = 1 << h;                             //`$n=2^h$`
    for (int i = 0; i < n; i += 2) f0[i / 2] = f[i]; //`偶数项复制到f0`
    ntt(f0, h - 1, type);                       //`将f0转成点值,$f0[k]=f_0(g_{n/2}^{k})$`
    for (int i = 1; i < n; i += 2) f1[i / 2] = f[i]; //`奇数项复制到f1`
    ntt(f1, h - 1, type);                       //`将f1转成点值,$f1[k]=f_1(g_{n/2}^{k})$`
    LL w = e[type * h];                         //`得到$n$次原根$g_n=e[h]=g^{\frac{p-1}{2^h}}$`
    LL x = 1;                                   //`用变量x计算$g_n^k$`
    for (int k = 0; k < n / 2; k++) {           //`$k$的枚举范围是$n/2$`
        f[k] = f0[k] + x * f1[k];               //`$f(g_n^k) \equiv f_0(g_{n/2}^{k})+g_n^kf_1(g_{n/2}^{k})$`
        f[k] %= P;                              //`$\pmod p$`
        f[k + n / 2] = f0[k] - x * f1[k];       //`$f(g_n^{k+n/2}) \equiv f_0(g_{n/2}^{k})-g_n^kf_1(g_{n/2}^{k})$`
        f[k + n / 2] %= P;                      //`$\pmod p$`
        x *= w; x %= P;                         //`保持$x \equiv g_n^k  \pmod p$`
    }
}

void polyInv(int n, LL* f, LL* g) {
    if (n == 1) {                          //`只有一项时`
        g[0] = qpow(f[0], P - 2, P);       //`令g[0]等于f[0]模$p$的乘法逆元`
        return;
    }
    polyInv((n + 1) >> 1, f, g);           //`递归调用polyInv,计算$\bmod {x^{\left\lceil \frac{n}{2} \right\rceil}}$的逆`
    int deg = 1, h = 0;
    while (deg < 2 * n) {deg <<= 1; h++;}  //`取$2^h$为大于等于$2n$的最小的$2$的幂`
    LL a[deg];
    for(int i=0;i<n;i++)a[i]=f[i];          //`只将f的前$n$项复制到数组a`
    for(int i=n;i<deg;i++)a[i]=0;
    ntt(a, h, 1);                          //`NTT求点值`
    ntt(g, h, 1);                          //`NTT求点值`
    for (int i = 0; i < deg; i++) {
        g[i] = (2 - a[i] * g[i] % P) * g[i] % P; //`计算$2H-FH^2$的点值`
    }
    ntt(g, h, -1);                         //`NTT求系数`
    LL inv = qpow(deg, P - 2, P);
    for (int i = 0; i < deg; i++) {
        g[i] = (g[i] * inv) % P;
        g[i] = (g[i] + P) % P;
    }
    for(int i=n;i<deg;i++) g[i]=0;//`从$n$次项开始的值设为0`
}


int n, m;
LL A[MAXN];
vector<LL> f[MAXN/2];
vector<LL> g[MAXN/2];
LL F1[MAXN*4],F2[MAXN*4],G1[MAXN*4],G2[MAXN*4];
void addFrac(int i1,int i2){
    int limit=f[i1].size()+f[i2].size()-1;//分母f1*f2的项数
    int deg=1,h=0;
    while(deg<limit) {deg<<=1;h++;}
    for(uint i=0;i<f[i1].size();i++) F1[i]=f[i1][i];
    for(int i=f[i1].size();i<deg;i++) F1[i]=0;
    for(uint i=0;i<f[i2].size();i++) F2[i]=f[i2][i];
    for(int i=f[i2].size();i<deg;i++) F2[i]=0;
    for(uint i=0;i<g[i1].size();i++) G1[i]=g[i1][i];
    for(int i=g[i1].size();i<deg;i++) G1[i]=0;
    for(uint i=0;i<g[i2].size();i++) G2[i]=g[i2][i];
    for(int i=g[i2].size();i<deg;i++) G2[i]=0;
    // printf("--------------------------------------------\n");
    // printf("f1:");for(int i=0;i<deg;i++)printf("%lld ",F1[i]);printf("\n");
    // printf("g1:");for(int i=0;i<deg;i++)printf("%lld ",G1[i]);printf("\n");
    // printf("f2:");for(int i=0;i<deg;i++)printf("%lld ",F2[i]);printf("\n");
    // printf("g2:");for(int i=0;i<deg;i++)printf("%lld ",G2[i]);printf("\n");
    ntt(F1,h,1);ntt(F2,h,1);
    ntt(G1,h,1);ntt(G2,h,1);
    for(int i=0;i<deg;i++) {
        G1[i]=(G1[i]*F2[i]%P+G2[i]*F1[i]%P)%P;
        F1[i]=F1[i]*F2[i]%P;
    }
    ntt(F1,h,-1);
    ntt(G1,h,-1);
    LL inv = qpow(deg, P - 2, P);
    for (int i = 0; i < deg; i++) {
        F1[i] = ((F1[i] * inv) % P + P)%P;
        G1[i] = ((G1[i] * inv) % P + P)%P;
    }
    limit=min(m+1,limit);
    f[i1].clear();g[i1].clear();
    for(int i=0;i<limit;i++) f[i1].push_back(F1[i]);
    for(int i=0;i<limit;i++) g[i1].push_back(G1[i]);
    // printf("f1:");for(uint i=0;i<f[i1].size();i++)printf("%lld ",f[i1][i]);printf("\n");
    // printf("g1:");for(uint i=0;i<g[i1].size();i++)printf("%lld ",g[i1][i]);printf("\n");    
}

int main() {
    primeroot(e, P, 3);
    scanf("%d%d", &n, &m);
    for (int i = 0; i < n; i++) {scanf("%lld", &A[i]);}
    for(int i=0;i<n/2;i++){
        f[i].push_back(1);
        f[i].push_back(-(A[i]+A[i+n/2])%P);
        f[i].push_back(A[i]*A[i+n/2]%P);
        g[i].push_back(2);
        g[i].push_back(-(A[i]+A[i+n/2])%P);
    }
    if(n&1){
        f[n/2].push_back(1);
        f[n/2].push_back(-A[n-1]%P);
        g[n/2].push_back(1);
        n=n/2+1;
    }
    else{
        n=n/2;
    }
    while(n>1){
        for(int i=0;i<n/2;i++){
            addFrac(i,i+n/2);
        }
        if(n&1){
            f[n/2]=f[n-1];
            g[n/2]=g[n-1];
            n=n/2+1;
        }
        else{
            n=n/2;
        }
    }
    int deg=1,h=0;
    while(deg<2*m+1) {deg<<=1;h++;}
    for(uint i=0;i<f[0].size();i++) F1[i]=f[0][i];
    for(int i=f[0].size();i<deg;i++) F1[i]=0;
    for(uint i=0;i<g[0].size();i++) G1[i]=g[0][i];
    for(int i=g[0].size();i<deg;i++) G1[i]=0;
    // printf("f1:");for(int i=0;i<deg;i++)printf("%lld ",F1[i]);printf("\n");
    // printf("g1:");for(int i=0;i<deg;i++)printf("%lld ",G1[i]);printf("\n");
    memset(G2,0,sizeof(G2));
    polyInv(m+1,F1,G2);
    // printf("g2:");for(int i=0;i<deg;i++)printf("%lld ",G2[i]);printf("\n"); 
    ntt(G1,h,1);
    ntt(G2,h,1);
    for(int i=0;i<deg;i++) G1[i]=G1[i]*G2[i]%P;
    ntt(G1,h,-1);
    LL inv = qpow(deg, P - 2, P);
    for (int i = 0; i < deg; i++) {
        G1[i] = ((G1[i] * inv) % P + P)%P;
    }
    for(int i=1;i<=m;i++)printf("%lld ",G1[i]);printf("\n");
    return 0;
}
0