結果

問題 No.2272 多項式乗算 mod 258280327
ユーザー MarioYC
提出日時 2023-04-14 23:24:10
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
TLE  
実行時間 -
コード長 2,457 bytes
コンパイル時間 2,454 ms
コンパイル使用メモリ 207,040 KB
最終ジャッジ日時 2025-02-12 08:20:28
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other TLE * 33
権限があれば一括ダウンロードができます

ソースコード

diff #
プレゼンテーションモードにする

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int mod = 258280327;
#pragma GCC optimize("Ofast,unroll-loops")
#pragma GCC target("avx,avx2,fma")
namespace {
template<int n, typename T>
void mult(const T *__restrict a, const T *__restrict b, T *__restrict res) {
if (n <= 64) { // if length is small then naive multiplication if faster
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
res[i + j] = (res[i + j] + (ll)a[i] * b[j]) % mod;
}
}
} else {
// cout << n << endl;
const int mid = n / 2;
alignas(64) T btmp[n], E[n] = {};
auto atmp = btmp + mid;
for (int i = 0; i < mid; i++) {
atmp[i] = a[i] + a[i + mid]; // atmp(x) - sum of two halfs a(x)
if(atmp[i] >= mod) atmp[i] -= mod;
btmp[i] = b[i] + b[i + mid]; // btmp(x) - sum of two halfs b(x)
if(btmp[i] >= mod) btmp[i] -= mod;
}
// cout << "sum" << endl;
mult<mid>(atmp, btmp, E); // Calculate E(x) = (alow(x) + ahigh(x)) * (blow(x) + bhigh(x))
// cout << "mult1" << endl;
mult<mid>(a + 0, b + 0, res); // Calculate rlow(x) = alow(x) * blow(x)
// cout << "mult2" << endl;
mult<mid>(a + mid, b + mid, res + n); // Calculate rhigh(x) = ahigh(x) * bhigh(x)
// cout << "mult3" << endl;
for (int i = 0; i < mid; i++) { // Then, calculate rmid(x) = E(x) - rlow(x) - rhigh(x) and write in memory
const auto tmp = res[i + mid];
res[i + mid] += E[i] - res[i] - res[i + 2 * mid];
res[i + mid] %= mod;
res[i + 2 * mid] += E[i + mid] - tmp - res[i + 3 * mid];
res[i + 2 * mid] %= mod;
}
// cout << "done" << endl;
}
}
}
const int nmax = (1 << 12) * 49;
alignas(64) static int a[nmax],b[nmax],ret[2 * nmax];
int main(){
ios_base::sync_with_stdio(false);
cin.tie(0);
int n,m;
cin >> n;
for(int i = 0;i <= n;++i) cin >> a[i];
cin >> m;
for(int i = 0;i <= m;++i) cin >> b[i];
mult<nmax, int>(a, b, ret);
cout << n + m << endl;
for(int i = 0;i <= n + m;++i){
auto x = ret[i];
if(ret[i] < 0) ret[i] += mod;
cout << x << ' ';
}
cout << endl;
return 0;
}
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
0