結果
問題 | No.2272 多項式乗算 mod 258280327 |
ユーザー |
|
提出日時 | 2023-04-14 23:24:10 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 |
TLE
|
実行時間 | - |
コード長 | 2,457 bytes |
コンパイル時間 | 2,454 ms |
コンパイル使用メモリ | 207,040 KB |
最終ジャッジ日時 | 2025-02-12 08:20:28 |
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | TLE * 33 |
ソースコード
// yukicoder No.2272: multiply two polynomials modulo 258280327 with Karatsuba.
#include <bits/stdc++.h>
using ll = long long;
constexpr int mod = 258280327;
#pragma GCC optimize("Ofast,unroll-loops")
#pragma GCC target("avx,avx2,fma")

// Largest supported polynomial length: 49 * 2^12, so repeated halving stops
// exactly at the 49-long schoolbook base case.
constexpr int nmax = (1 << 12) * 49;
// Smallest block length produced by halving nmax.
constexpr int kBaseLen = nmax >> 12;

namespace {

using u64 = unsigned long long;

// Lengths at or below this use the schoolbook base case.
constexpr int kNaiveMax = 64;

// The base case sums up to kNaiveMax products (each < mod^2) plus one residue
// in a 64-bit accumulator before reducing, so that sum must not overflow.
static_assert((std::numeric_limits<u64>::max() - mod) /
                      (static_cast<u64>(mod - 1) * (mod - 1)) >=
                  static_cast<u64>(kNaiveMax),
              "64-bit accumulator could overflow in the schoolbook base case");

// Returns the representative of v in [0, mod).
inline ll reduce(ll v) {
    v %= mod;
    return v < 0 ? v + mod : v;
}

/// Multiplies two length-n polynomials modulo `mod` (Karatsuba).
///
/// @param a    n coefficients, each in [0, mod), lowest degree first.
/// @param b    n coefficients, each in [0, mod), lowest degree first.
/// @param res  at least 2n slots; res[0, 2n) must be zero on entry. On return
///             res[0, 2n-1) holds the product, every value in [0, mod), and
///             res[2n-1] is left untouched.
/// Each recursion level keeps two length-n scratch arrays on the stack.
template <int n, typename T>
void mult(const T *__restrict a, const T *__restrict b, T *__restrict res) {
    static_assert(n > 0, "length must be positive");
    if constexpr (n <= kNaiveMax) {
        // Schoolbook: accumulate exact 64-bit sums, reduce once per output.
        u64 acc[2 * n - 1] = {};
        for (int i = 0; i < n; ++i) {
            const u64 ai = static_cast<std::uint32_t>(a[i]);
            for (int j = 0; j < n; ++j) {
                acc[i + j] += ai * static_cast<std::uint32_t>(b[j]);
            }
        }
        for (int k = 0; k < 2 * n - 1; ++k) {
            res[k] = static_cast<T>((acc[k] + static_cast<u64>(res[k])) % mod);
        }
    } else {
        static_assert(n % 2 == 0, "Karatsuba split requires an even length");
        constexpr int mid = n / 2;
        alignas(64) T btmp[n];
        alignas(64) T E[n] = {};
        T *const atmp = btmp + mid;
        for (int i = 0; i < mid; ++i) {
            // atmp = alow + ahigh, btmp = blow + bhigh, kept in [0, mod).
            atmp[i] = a[i] + a[i + mid];
            if (atmp[i] >= mod) atmp[i] -= mod;
            btmp[i] = b[i] + b[i + mid];
            if (btmp[i] >= mod) btmp[i] -= mod;
        }
        mult<mid>(atmp, btmp, E);               // E = (alow+ahigh)(blow+bhigh)
        mult<mid>(a, b, res);                   // rlow  -> res[0, n-1)
        mult<mid>(a + mid, b + mid, res + n);   // rhigh -> res[n, 2n-1)
        // Add rmid = E - rlow - rhigh at offset mid.
        for (int i = 0; i < mid; ++i) {
            // rlow[i + mid]; saved because res[i + mid] is overwritten first.
            const ll low_tail = res[i + mid];
            res[i + mid] = static_cast<T>(
                reduce(low_tail + E[i] - res[i] - res[i + 2 * mid]));
            res[i + 2 * mid] = static_cast<T>(
                reduce(static_cast<ll>(res[i + 2 * mid]) + E[i + mid] -
                       low_tail - res[i + 3 * mid]));
        }
    }
}

/// Runs mult at the smallest length len * 2^k (capped at nmax) that is >= need,
/// so small inputs don't pay for a full nmax-sized multiplication.
template <int len>
void mult_fit(int need, const int *a, const int *b, int *res) {
    if constexpr (len >= nmax) {
        mult<nmax>(a, b, res);
    } else {
        if (need <= len) {
            mult<len>(a, b, res);
        } else {
            mult_fit<len * 2>(need, a, b, res);
        }
    }
}

}  // namespace

alignas(64) static int a[nmax], b[nmax], ret[2 * nmax];

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // Reads a degree followed by degree + 1 coefficients, reduced into [0, mod).
    auto read_poly = [](int &deg, int *dst) -> bool {
        if (!(std::cin >> deg) || deg < 0 || deg >= nmax) return false;
        for (int i = 0; i <= deg; ++i) {
            ll x;
            if (!(std::cin >> x)) return false;
            dst[i] = static_cast<int>(reduce(x));
        }
        return true;
    };

    int n = 0;
    int m = 0;
    if (!read_poly(n, a) || !read_poly(m, b)) {
        std::cerr << "invalid input or degree too large\n";
        return 1;
    }

    mult_fit<kBaseLen>(std::max(n, m) + 1, a, b, ret);

    // Coefficients are already normalized to [0, mod).
    std::cout << n + m << '\n';
    for (int i = 0; i <= n + m; ++i) std::cout << ret[i] << ' ';
    std::cout << '\n';
    return 0;
}