結果
| 問題 |
No.2272 多項式乗算 mod 258280327
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2023-04-16 04:23:52 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 2,623 bytes |
| コンパイル時間 | 3,413 ms |
| コンパイル使用メモリ | 253,264 KB |
| 実行使用メモリ | 15,872 KB |
| 最終ジャッジ日時 | 2024-10-11 11:23:28 |
| 合計ジャッジ時間 | 23,820 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 31 WA * 2 |
ソースコード
#include <bits/stdc++.h>
#pragma GCC optimize("Ofast,unroll-loops")
#pragma GCC target("avx,avx2,fma")
using namespace std;
using ll = long long;
// Modulus used for every coefficient reduction in this file.
const int mod = 258280327;
namespace {
/// Multiplies two length-`n` coefficient arrays with Karatsuba's algorithm,
/// accumulating the (partially reduced) product into `res`.
///
/// @tparam n   Number of coefficients in each operand. Sizes up to 64 use the
///             schoolbook method; larger sizes are split in half and must
///             therefore be even at every recursion level above 64.
/// @tparam T   Coefficient type (deduced). The base case multiplies and sums
///             in `T` without reducing, so `T` has to be wide enough for that;
///             with a 64-bit signed `T` and inputs in [0, mod) a sum of 64
///             products stays below 64 * mod^2 < 2^63.
/// @param a    First operand, `n` readable entries.
/// @param b    Second operand, `n` readable entries.
/// @param res  Output, `2 * n` writable entries. Must be all zeros on entry,
///             because results are added with `+=`. On return the entries are
///             congruent to the product coefficients modulo `mod`, but they
///             are NOT normalised: they can be negative (C++ `%` keeps the
///             sign) or, in the outer parts, not reduced at all. The caller
///             has to apply `(x % mod + mod) % mod`.
template<int n, typename T>
void mult(const T *__restrict a, const T *__restrict b, T *__restrict res) {
    // `if constexpr` keeps the recursive branch from being instantiated for
    // small sizes; a plain `if` would instantiate mult<n/2> all the way down
    // to mult<0>, which declares zero-length arrays.
    if constexpr (n <= 64) {
        // Small operands: the quadratic schoolbook product is faster here.
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                res[i + j] += a[i] * b[j];
            }
        }
    } else {
        static_assert(n % 2 == 0,
                      "mult: sizes above 64 are halved, so n must be even at every level");
        constexpr int mid = n / 2;

        // One scratch buffer holds both half-sums: btmp in [0, mid), atmp in [mid, n).
        // E receives the product of the half-sums and has to start zeroed.
        alignas(64) T btmp[n], E[n] = {};
        T *atmp = btmp + mid;
        for (int i = 0; i < mid; i++) {
            atmp[i] = (a[i] + a[i + mid]) % mod;  // alow(x) + ahigh(x)
            btmp[i] = (b[i] + b[i + mid]) % mod;  // blow(x) + bhigh(x)
        }

        mult<mid>(atmp, btmp, E);             // E(x)     = (alow + ahigh) * (blow + bhigh)
        mult<mid>(a + 0, b + 0, res);         // rlow(x)  = alow * blow   -> res[0 .. n)
        mult<mid>(a + mid, b + mid, res + n); // rhigh(x) = ahigh * bhigh -> res[n .. 2n)

        // rmid(x) = E(x) - rlow(x) - rhigh(x) is added at offset mid. It
        // overlaps the upper half of rlow and the lower half of rhigh, so the
        // original res[i + mid] is saved before it gets overwritten.
        for (int i = 0; i < mid; i++) {
            const auto tmp = res[i + mid];
            res[i + mid] += E[i] - res[i] - res[i + 2 * mid];
            res[i + mid] %= mod;
            res[i + 2 * mid] += E[i + mid] - tmp - res[i + 3 * mid];
            res[i + 2 * mid] %= mod;
        }
    }
}
}
// Operand length handed to mult<>: 49 * 2^12 = 200704 coefficients.
const int nmax = 49 * (1 << 12);
// Static storage, so all three buffers start out zero-filled.
alignas(64) static ll a[nmax];        // coefficients of the first polynomial
alignas(64) static ll b[nmax];        // coefficients of the second polynomial
alignas(64) static ll ret[2 * nmax];  // product coefficients written by mult<>
/// Reads two polynomials from stdin (a degree followed by degree + 1
/// coefficients, twice), multiplies them modulo `mod` and prints the degree of
/// the product followed by its coefficients, lowest order first.
///
/// Returns 1 without producing output if a degree cannot be read or does not
/// fit into the static buffers.
int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m;
    // a[] and b[] hold nmax entries each; reject degrees that would write past them.
    if (!(cin >> n) || n < 0 || n >= nmax) return 1;
    for (int i = 0; i <= n; ++i) {
        cin >> a[i];
        a[i] = (a[i] % mod + mod) % mod;  // keep operands in [0, mod) as mult<> expects
    }
    if (!(cin >> m) || m < 0 || m >= nmax) return 1;
    for (int i = 0; i <= m; ++i) {
        cin >> b[i];
        b[i] = (b[i] % mod + mod) % mod;
    }

    mult<nmax>(a, b, ret);

    // mult<> leaves entries that may be negative or unreduced; normalise them
    // before looking at their values.
    int deg = n + m;
    for (int i = 0; i <= deg; ++i) {
        ret[i] = (ret[i] % mod + mod) % mod;
    }

    // The degree is the index of the highest non-zero coefficient, not always
    // n + m: if either input is the zero polynomial the whole product is zero.
    // NOTE(review): assumes the judge wants the true degree, with the zero
    // polynomial reported as degree 0 — confirm against the problem statement.
    while (deg > 0 && ret[deg] == 0) {
        --deg;
    }

    cout << deg << '\n';
    for (int i = 0; i <= deg; ++i) {
        cout << ret[i] << ' ';
    }
    cout << '\n';
    return 0;
}