結果
問題 | No.1145 Sums of Powers |
ユーザー | ace_amuro |
提出日時 | 2021-01-13 19:32:00 |
言語 | C++14 (gcc 12.3.0 + boost 1.83.0) |
結果 |
AC
|
実行時間 | 1,213 ms / 2,000 ms |
コード長 | 6,862 bytes |
コンパイル時間 | 1,074 ms |
コンパイル使用メモリ | 83,824 KB |
実行使用メモリ | 45,656 KB |
最終ジャッジ日時 | 2024-11-22 15:01:30 |
合計ジャッジ時間 | 5,875 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 6 ms
9,216 KB |
testcase_01 | AC | 5 ms
9,088 KB |
testcase_02 | AC | 12 ms
9,344 KB |
testcase_03 | AC | 1,213 ms
45,484 KB |
testcase_04 | AC | 1,203 ms
45,504 KB |
testcase_05 | AC | 1,208 ms
45,656 KB |
ソースコード
#include<cstdio> #include<cstring> #include<iostream> #include<cmath> #include<vector> #include<algorithm> using namespace std; typedef long long LL; typedef unsigned int uint; LL ww[101]; LL* e = ww + 50; const int MAXN = 1e5 + 5; const LL P = 998244353; LL qpow(LL a, LL k, LL p) { LL c = 1; a %= p; while (k) { if (k & 1) c = (c * a) % p; k >>= 1; a = (a * a) % p; } return c; } void primeroot(LL* e, const LL& P, const LL& g) { int s = 0; LL q = P - 1; //`将$p$分解成$p=q\times2^s+1$的形式` while ((q & 1) == 0) {s++; q >>= 1;} //`计算$q$和$s$` LL w = qpow(g, q, P); //`先计算$2^s$次原根$g^{\frac{p-1}{2^s}}$` LL invw = qpow(w, P - 2, P); //`$2^s$次原根的逆元` for (int h = s; h >= 0; h--) { e[h] = w; //`$e[h]=g^{\frac{p-1}{2^h}}$` e[-h] = invw; //`$e[-h]=g^{-\frac{p-1}{2^h}}$` w = (w * w) % P; //`$2^{h}$次原根的平方等于$2^{h-1}$次原根` invw = (invw * invw) % P; } } void ntt(LL* f, const int& h, const int& type) { if (h == 0) return; //`递归出口` LL f0[1 << (h - 1)], f1[1 << (h - 1)]; //`新建两个大小为$2^{h-1}$的数组` int n = 1 << h; //`$n=2^h$` for (int i = 0; i < n; i += 2) f0[i / 2] = f[i]; //`偶数项复制到f0` ntt(f0, h - 1, type); //`将f0转成点值,$f0[k]=f_0(g_{n/2}^{k})$` for (int i = 1; i < n; i += 2) f1[i / 2] = f[i]; //`奇数项复制到f1` ntt(f1, h - 1, type); //`将f1转成点值,$f1[k]=f_1(g_{n/2}^{k})$` LL w = e[type * h]; //`得到$n$次原根$g_n=e[h]=g^{\frac{p-1}{2^h}}$` LL x = 1; //`用变量x计算$g_n^k$` for (int k = 0; k < n / 2; k++) { //`$k$的枚举范围是$n/2$` f[k] = f0[k] + x * f1[k]; //`$f(g_n^k) \equiv f_0(g_{n/2}^{k})+g_n^kf_1(g_{n/2}^{k})$` f[k] %= P; //`$\pmod p$` f[k + n / 2] = f0[k] - x * f1[k]; //`$f(g_n^{k+n/2}) \equiv f_0(g_{n/2}^{k})-g_n^kf_1(g_{n/2}^{k})$` f[k + n / 2] %= P; //`$\pmod p$` x *= w; x %= P; //`保持$x \equiv g_n^k \pmod p$` } } void polyInv(int n, LL* f, LL* g) { if (n == 1) { //`只有一项时` g[0] = qpow(f[0], P - 2, P); //`令g[0]等于f[0]模$p$的乘法逆元` return; } polyInv((n + 1) >> 1, f, g); //`递归调用polyInv,计算$\bmod {x^{\left\lceil \frac{n}{2} \right\rceil}}$的逆` int deg = 1, h = 0; while (deg < 2 * n) {deg <<= 1; h++;} //`取$2^h$为大于等于$2n$的最小的$2$的幂` LL a[deg]; for(int i=0;i<n;i++)a[i]=f[i]; //`只将f的前$n$项复制到数组a` for(int i=n;i<deg;i++)a[i]=0; ntt(a, h, 1); //`NTT求点值` ntt(g, h, 1); //`NTT求点值` for (int i = 0; i < deg; i++) { g[i] = (2 - a[i] * g[i] % P) * g[i] % P; //`计算$2H-FH^2$的点值` } ntt(g, h, -1); //`NTT求系数` LL inv = qpow(deg, P - 2, P); for (int i = 0; i < deg; i++) { g[i] = (g[i] * inv) % P; g[i] = (g[i] + P) % P; } for(int i=n;i<deg;i++) g[i]=0;//`从$n$次项开始的值设为0` } int n, m; LL A[MAXN]; vector<LL> f[MAXN/2]; vector<LL> g[MAXN/2]; LL F1[MAXN*4],F2[MAXN*4],G1[MAXN*4],G2[MAXN*4]; void addFrac(int i1,int i2){ int limit=f[i1].size()+f[i2].size()-1;//分母f1*f2的项数 int deg=1,h=0; while(deg<limit) {deg<<=1;h++;} for(uint i=0;i<f[i1].size();i++) F1[i]=f[i1][i]; for(int i=f[i1].size();i<deg;i++) F1[i]=0; for(uint i=0;i<f[i2].size();i++) F2[i]=f[i2][i]; for(int i=f[i2].size();i<deg;i++) F2[i]=0; for(uint i=0;i<g[i1].size();i++) G1[i]=g[i1][i]; for(int i=g[i1].size();i<deg;i++) G1[i]=0; for(uint i=0;i<g[i2].size();i++) G2[i]=g[i2][i]; for(int i=g[i2].size();i<deg;i++) G2[i]=0; // printf("--------------------------------------------\n"); // printf("f1:");for(int i=0;i<deg;i++)printf("%lld ",F1[i]);printf("\n"); // printf("g1:");for(int i=0;i<deg;i++)printf("%lld ",G1[i]);printf("\n"); // printf("f2:");for(int i=0;i<deg;i++)printf("%lld ",F2[i]);printf("\n"); // printf("g2:");for(int i=0;i<deg;i++)printf("%lld ",G2[i]);printf("\n"); ntt(F1,h,1);ntt(F2,h,1); ntt(G1,h,1);ntt(G2,h,1); for(int i=0;i<deg;i++) { G1[i]=(G1[i]*F2[i]%P+G2[i]*F1[i]%P)%P; F1[i]=F1[i]*F2[i]%P; } ntt(F1,h,-1); ntt(G1,h,-1); LL inv = qpow(deg, P - 2, P); for (int i = 0; i < deg; i++) { F1[i] = ((F1[i] * inv) % P + P)%P; G1[i] = ((G1[i] * inv) % P + P)%P; } limit=min(m+1,limit); f[i1].clear();g[i1].clear(); for(int i=0;i<limit;i++) f[i1].push_back(F1[i]); for(int i=0;i<limit;i++) g[i1].push_back(G1[i]); // printf("f1:");for(uint i=0;i<f[i1].size();i++)printf("%lld ",f[i1][i]);printf("\n"); // printf("g1:");for(uint i=0;i<g[i1].size();i++)printf("%lld ",g[i1][i]);printf("\n"); } int main() { primeroot(e, P, 3); scanf("%d%d", &n, &m); for (int i = 0; i < n; i++) {scanf("%lld", &A[i]);} for(int i=0;i<n/2;i++){ f[i].push_back(1); f[i].push_back(-(A[i]+A[i+n/2])%P); f[i].push_back(A[i]*A[i+n/2]%P); g[i].push_back(2); g[i].push_back(-(A[i]+A[i+n/2])%P); } if(n&1){ f[n/2].push_back(1); f[n/2].push_back(-A[n-1]%P); g[n/2].push_back(1); n=n/2+1; } else{ n=n/2; } while(n>1){ for(int i=0;i<n/2;i++){ addFrac(i,i+n/2); } if(n&1){ f[n/2]=f[n-1]; g[n/2]=g[n-1]; n=n/2+1; } else{ n=n/2; } } int deg=1,h=0; while(deg<2*m+1) {deg<<=1;h++;} for(uint i=0;i<f[0].size();i++) F1[i]=f[0][i]; for(int i=f[0].size();i<deg;i++) F1[i]=0; for(uint i=0;i<g[0].size();i++) G1[i]=g[0][i]; for(int i=g[0].size();i<deg;i++) G1[i]=0; // printf("f1:");for(int i=0;i<deg;i++)printf("%lld ",F1[i]);printf("\n"); // printf("g1:");for(int i=0;i<deg;i++)printf("%lld ",G1[i]);printf("\n"); memset(G2,0,sizeof(G2)); polyInv(m+1,F1,G2); // printf("g2:");for(int i=0;i<deg;i++)printf("%lld ",G2[i]);printf("\n"); ntt(G1,h,1); ntt(G2,h,1); for(int i=0;i<deg;i++) G1[i]=G1[i]*G2[i]%P; ntt(G1,h,-1); LL inv = qpow(deg, P - 2, P); for (int i = 0; i < deg; i++) { G1[i] = ((G1[i] * inv) % P + P)%P; } for(int i=1;i<=m;i++)printf("%lld ",G1[i]);printf("\n"); return 0; }