結果

問題 No.2272 多項式乗算 mod 258280327
ユーザー Dmitrii KozyrevDmitrii Kozyrev
提出日時 2023-04-16 04:44:09
言語 C++23
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 3,173 bytes
コンパイル時間 4,460 ms
コンパイル使用メモリ 286,496 KB
実行使用メモリ 16,256 KB
最終ジャッジ日時 2024-04-19 19:50:54
合計ジャッジ時間 31,817 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 826 ms
12,800 KB
testcase_01 AC 796 ms
12,800 KB
testcase_02 AC 802 ms
12,800 KB
testcase_03 AC 804 ms
12,800 KB
testcase_04 AC 818 ms
13,056 KB
testcase_05 AC 801 ms
12,928 KB
testcase_06 AC 823 ms
13,056 KB
testcase_07 AC 809 ms
12,928 KB
testcase_08 AC 815 ms
12,928 KB
testcase_09 AC 844 ms
13,056 KB
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 AC 848 ms
12,928 KB
testcase_16 AC 807 ms
12,928 KB
testcase_17 AC 824 ms
12,800 KB
testcase_18 AC 827 ms
12,928 KB
testcase_19 AC 833 ms
12,800 KB
testcase_20 AC 890 ms
12,928 KB
testcase_21 AC 894 ms
12,800 KB
testcase_22 WA -
testcase_23 WA -
testcase_24 WA -
testcase_25 WA -
testcase_26 WA -
testcase_27 WA -
testcase_28 WA -
testcase_29 WA -
testcase_30 WA -
testcase_31 WA -
testcase_32 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>

#pragma GCC optimize("Ofast,unroll-loops")
#pragma GCC target("avx,avx2,fma") 
using namespace std;
//using ll = long long;

const int mod = 258280327;

namespace {
    template<int n, typename T>
    void mult(const T *__restrict a, const T *__restrict b, T *__restrict res) {
        if (n <= 16) { // if length is small then naive multiplication if faster
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < n; j++) {
                    //res[i + j] += (res[i + j] + (ll)a[i] * b[j]) % mod;
                    res[i + j] += a[i] * b[j];
                }
            }
            for(int i = 0; i <= (n-1) + (n-1); i++)
                res[i] %= mod;
        } else {
            // cout << n << endl;
            const int mid = n / 2;
            alignas(64) T btmp[n], E[n] = {};
            auto atmp = btmp + mid;
            for (int i = 0; i < mid; i++) {
                atmp[i] = (a[i] + (a[i + mid] - mod)) % mod; // atmp(x) - sum of two halfs a(x)
                //if(atmp[i] >= mod) atmp[i] -= mod;
                btmp[i] = (b[i] + (b[i + mid] - mod)) % mod; // btmp(x) - sum of two halfs b(x)
                //if(btmp[i] >= mod) btmp[i] -= mod;
            }
            // cout << "sum" << endl;
            mult<mid>(atmp, btmp, E); // Calculate E(x) = (alow(x) + ahigh(x)) * (blow(x) + bhigh(x))
            // cout << "mult1" << endl;
            mult<mid>(a + 0, b + 0, res); // Calculate rlow(x) = alow(x) * blow(x)
            // cout << "mult2" << endl;
            mult<mid>(a + mid, b + mid, res + n); // Calculate rhigh(x) = ahigh(x) * bhigh(x)
            // cout << "mult3" << endl;
            for (int i = 0; i < mid; i++) { // Then, calculate rmid(x) = E(x) - rlow(x) - rhigh(x) and write in memory
                const auto tmp = res[i + mid];
                res[i + mid] += E[i] - res[i] - res[i + 2 * mid];
                res[i + mid] %= mod;
                res[i + 2 * mid] += E[i + mid] - tmp - res[i + 3 * mid];
                res[i + 2 * mid] %= mod;
            }
            // cout << "done" << endl;
        }
    }
}

const int nmax = (1 << 12) * 49;
alignas(64) static int64_t a[nmax],b[nmax],ret[2 * nmax];

int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    int n,m;
    cin >> n;
    for(int i = 0; i <= n;++i) {
        cin >> a[i];
        //a[i] %= mod;
        //a[i] %= mod;
    }
    while (n >= 0 && a[n] == 0) n--;
    cin >> m;
    for(int i = 0;i <= m;++i) {
        cin >> b[i];
        //b[i] %= mod;
        //b[i] %= mod;
    }
    while (m >= 0 && b[m] == 0) m--;
    if (n < 0 || m < 0) {
        //std::cout << "0\n0\n" << std::endl;
        return 0;
    }
    for (int i = 0; i <= n; i++)
        a[i] %= mod;
    for (int i = 0; i <= n; i++)
        b[i] %= mod;
    mult<nmax>(a, b, ret);
    for(int i = 0; i <= n + m;++i)
        ((ret[i] %= mod) += mod) %= mod;
    int leading = n+m;
    //while (leading > 0 && ret[leading] == 0) leading--;
    cout << leading << endl;
    for(int i = 0;i <= leading;++i){
        auto x = (ret[i] % mod + mod) % mod;
        cout << x << ' ';
    }
    cout << endl;
    return 0;
}
0