
問題 No.1068 #いろいろな色 / Red and Blue and more various colors (Hard)
ユーザー SPARKLE_math
提出日時 2023-11-10 20:11:28
言語 C++17(clang)
(17.0.6 + boost 1.87.0)
実行時間 657 ms / 3,500 ms
コード長 14,119 bytes
コンパイル時間 4,913 ms
コンパイル使用メモリ 150,328 KB
実行使用メモリ 14,976 KB
最終ジャッジ日時 2024-09-26 00:51:50
合計ジャッジ時間 19,382 ms
judge4 / judge2
ファイルパターン 結果
sample AC * 3
other AC * 29


diff #

#line 1 "main.cpp"
#include <assert.h>
#include <math.h>
#include <unistd.h>

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <deque>
#include <functional>
#include <iomanip>
#include <iostream>
#include <map>
#include <numeric>
#include <queue>
#include <set>
#include <stack>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <variant>
#include <vector>
#line 2 "/home/kokoro601/compro_library/Utils/FastIO.hpp"
#include <string.h>
#line 4 "/home/kokoro601/compro_library/Utils/FastIO.hpp"

// Buffered replacements for std::cin / std::cout built on fread / fwrite.
// Scanner tokenises stdin through `inbuf`; Printer formats into `outbuf`
// and writes it out when the buffer fills up or the Printer is destroyed.
namespace fastio{
    static constexpr size_t buf_size = 1 << 18;   // capacity of each I/O buffer, in bytes
    static constexpr size_t integer_size = 20;    // max decimal digits of a 64-bit integer
    static constexpr size_t block_size = 10000;   // integers are emitted 4 digits (one block) at a time

    static char inbuf[buf_size + 1] = {};         // +1 for the '\0' terminator written after the data
    static char outbuf[buf_size + 1] = {};
    static char block_str[block_size * 4 + 1] = {};  // "0000".."9999" laid out back to back

    static constexpr uint64_t power10[] = {
        1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000,
        1000000000, 10000000000, 100000000000, 1000000000000, 10000000000000,
        100000000000000, 1000000000000000, 10000000000000000, 100000000000000000,
        1000000000000000000, 10000000000000000000u
    };

    // Whitespace-delimited token reader over stdin.
    // Any byte <= ' ' (as a plain `char`) is treated as a separator.
    struct Scanner {
        // Read cursor and number of valid bytes in inbuf. Initialised explicitly so a
        // Scanner with automatic storage does not start from indeterminate values.
        size_t pos = 0, end = 0;

        // Fill the buffer from scratch.
        void load() {
            end = fread(inbuf,1,buf_size,stdin);
            inbuf[end] = '\0';
            pos = 0;
        }

        // Move the unread tail to the front of the buffer and top it up from stdin.
        void reload() {
            // pos can be one past `end` after the terminator itself was consumed at EOF;
            // clamp so the length never underflows.
            size_t len = pos < end ? end - pos : 0;
            memmove(inbuf,inbuf + pos,len);
            end = len + fread(inbuf + len,1,buf_size - len,stdin);
            inbuf[end] = '\0';
            pos = 0;
        }

        // Advance to the next non-separator byte; stops quietly at end of input.
        void skip_space() {
            while(inbuf[pos] <= ' '){
                if(__builtin_expect(pos >= end, 0)) {
                    // Sitting on the terminator: fetch more data or give up at EOF.
                    reload();
                    if (end == 0) return;
                    continue;
                }
                ++pos;
            }
        }

        char get_next() { return inbuf[pos++]; }
        char get_next_nonspace() {
            skip_space();
            return inbuf[pos++];
        }

        Scanner() { load(); }

        // Read the next non-separator character.
        void scan(char& c) { c = get_next_nonspace(); }

        // Read the next whitespace-delimited word; it may span several buffer refills.
        void scan(std::string& s){
            skip_space();
            s.clear();
            while (true) {
                size_t start = pos;
                while (inbuf[pos] > ' ') pos++;
                s.append(inbuf + start, inbuf + pos);
                if (pos < end) break;   // stopped on a separator inside the buffer
                reload();               // ran into the terminator: the word may continue
                if (end == 0) break;    // nothing left on stdin
            }
        }

        // Read a decimal integer with an optional leading '-'.
        // The single byte following the digits is consumed as well.
        template <class T>
        typename std::enable_if<std::is_integral<T>::value, void>::type scan(T &x) {
            char c = get_next_nonspace();
            // Make sure the whole number is in the buffer before parsing it.
            if(__builtin_expect(pos + integer_size >= end, 0)) reload();
            bool neg = false;
            if (c == '-') neg = true, x = 0;
            else x = c & 15;
            while ((c = get_next()) >= '0') x = x * 10 + (c & 15);
            if (neg) x = -x;
        }

        // Read several values in order.
        template <class Head, class... Others>
        void scan(Head& head, Others&... others) {
            scan(head); scan(others...);
        }

        template <class T>
        Scanner& operator >> (T& x) {
            scan(x); return *this;
        }
    };

    // Buffered writer to stdout; flushes when full and on destruction.
    struct Printer {
        size_t pos = 0;   // number of pending bytes in outbuf

        void flush() {
            fwrite(outbuf, 1, pos, stdout);
            pos = 0;
        }

        // Fill block_str with the zero-padded 4-digit text of every value below block_size.
        void pre_calc() {
            for (size_t i = 0; i < block_size; i++) {
                size_t j = 4, k = i;
                while (j--) {
                    block_str[i * 4 + j] = k % 10 + '0';
                    k /= 10;
                }
            }
        }

        // Number of decimal digits of n (1 for n == 0).
        static constexpr size_t get_integer_size(uint64_t n) {
            if(n >= power10[10]) {
                if (n >= power10[19]) return 20;
                if (n >= power10[18]) return 19;
                if (n >= power10[17]) return 18;
                if (n >= power10[16]) return 17;
                if (n >= power10[15]) return 16;
                if (n >= power10[14]) return 15;
                if (n >= power10[13]) return 14;
                if (n >= power10[12]) return 13;
                if (n >= power10[11]) return 12;
                return 11;
            }
            else {
                if (n >= power10[9]) return 10;
                if (n >= power10[8]) return 9;
                if (n >= power10[7]) return 8;
                if (n >= power10[6]) return 7;
                if (n >= power10[5]) return 6;
                if (n >= power10[4]) return 5;
                if (n >= power10[3]) return 4;
                if (n >= power10[2]) return 3;
                if (n >= power10[1]) return 2;
                return 1;
            }
        }

        Printer() { pre_calc(); }
        ~Printer() { flush(); }

        void print(char c){
            outbuf[pos++] = c;
            if (__builtin_expect(pos == buf_size, 0)) flush();
        }
        void print(const char *s) {
            while(*s != 0) {
                outbuf[pos++] = *s++;
                if (__builtin_expect(pos == buf_size, 0)) flush();
            }
        }
        void print(const std::string& s) {
            for(auto c : s){
                outbuf[pos++] = c;
                if (__builtin_expect(pos == buf_size, 0)) flush();
            }
        }

        // Print an integer in decimal.
        template <class T>
        typename std::enable_if<std::is_integral<T>::value, void>::type print(T x) {
            // Reserve room for sign + digits up front so the memcpy calls below cannot overrun.
            if (__builtin_expect(pos + integer_size >= buf_size, 0)) flush();
            // Work on the unsigned magnitude: negating the most negative signed value
            // directly would overflow.
            uint64_t mag = static_cast<uint64_t>(x);
            if (x < 0) print('-'), mag = uint64_t(0) - mag;
            size_t digit = get_integer_size(mag);
            size_t len = digit;
            // Emit from the least significant end, four digits per step.
            while (len >= 4) {
                len -= 4;
                memcpy(outbuf + pos + len, block_str + (mag % block_size) * 4, 4);
                mag /= block_size;
            }
            // The remaining 0..3 leading digits are the tail of mag's 4-character block.
            memcpy(outbuf + pos, block_str + mag * 4 + (4 - len), len);
            pos += digit;
        }

        // Print several values separated by single spaces.
        template <class Head, class... Others>
        void print(const Head& head, const Others&... others){
            print(head); print(' '); print(others...);
        }

        // Like print(args...), followed by a newline.
        template <class... Args>
        void println(const Args&... args) {
            print(args...); print('\n');
        }
        template <class T>
        Printer& operator << (const T& x) {
            print(x); return *this;
        }
    };
}

// Global fast-I/O objects. Constructing `fin` immediately reads a first chunk of
// stdin (Scanner's constructor calls load()); `fout` builds its digit table on
// construction and flushes any pending output in its destructor.
fastio::Scanner fin;
fastio::Printer fout;
// From this point on, every `cin` / `cout` token in the file refers to the
// buffered objects above rather than the iostream ones.
#define cin fin
#define cout fout
#line 4 "/home/kokoro601/compro_library/Utils/ModInt.hpp"

// Integer modulo the compile-time constant MD, stored as a canonical value in [0, MD).
// inv() and operator/ use Fermat's little theorem, so they are only meaningful
// when MD is prime and the divisor is non-zero.
template <uint32_t MD> struct ModInt {
    // v + r.v and (_v % MD + MD) are formed in 32-bit unsigned arithmetic before
    // being reduced, so 2 * MD - 1 must fit in 32 bits.
    static_assert(MD >= 1 && MD <= 0x80000000u, "ModInt requires 1 <= MD <= 2^31");

    using M = ModInt;
    using uint = uint32_t;
    using ull = uint64_t;
    using ll = int64_t;
    uint v;   // canonical representative

    // Accepts any signed 64-bit value, negative ones included.
    ModInt(ll _v = 0) { set_v(_v % MD + MD); }

    // Store a value already known to lie in [0, 2 * MD); one conditional subtraction normalises it.
    M& set_v(uint _v) {
        v = (_v < MD) ? _v : _v - MD;
        return *this;
    }

    explicit operator bool() const { return v != 0; }
    M operator-() const { return M() - *this; }
    M operator+(const M& r) const { return M().set_v(v + r.v); }
    M operator-(const M& r) const { return M().set_v(v + MD - r.v); } // "v + MD - r.v" can reach MD or more, so set_v is needed
    M operator*(const M& r) const { return M().set_v(ull(v) * r.v % MD); }
    M operator/(const M& r) const { return *this * r.inv(); }
    M& operator+=(const M& r) { return *this = *this + r; }
    M& operator-=(const M& r) { return *this = *this - r; }
    M& operator*=(const M& r) { return *this = *this * r; }
    M& operator/=(const M& r) { return *this = *this / r; }
    bool operator==(const M& r) const { return v == r.v; }
    bool operator!=(const M& r) const { return v != r.v;}

    // Binary exponentiation: returns (*this)^n.
    M pow(ull n) const {
        M x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }

    // Multiplicative inverse computed as (*this)^(MD - 2).
    M inv() const { return pow(MD - 2); }

    friend std::ostream& operator<<(std::ostream& os, const M& r) { return os << r.v; }
};
#line 2 "/home/kokoro601/compro_library/Math/ExtGCD.hpp"

// Calculate the solution of ax + by = gcd(a, b)
// and return gcd(a, b)
// Extended Euclidean algorithm.
// Writes into x and y one integer solution of a*x + b*y = gcd(a, b)
// and returns gcd(a, b).
long long extGCD(long long a, long long b, long long &x, long long &y) {
    if (b == 0) {
        // gcd(a, 0) = a = a*1 + 0*0
        x = 1;
        y = 0;
        return a;
    }

    // Solve for (b, a % b) with the roles of x and y swapped, then undo the
    // substitution a % b = a - (a / b) * b.
    long long d = extGCD(b, a % b, y, x);
    y -= (a / b) * x;
    return d;
}

// Greatest common divisor of a and b, obtained from extGCD with the Bezout
// coefficients discarded.
long long gcd(long long a, long long b) {
    long long unused1, unused2;
    return extGCD(a, b, unused1, unused2);
}
#line 4 "/home/kokoro601/compro_library/Math/ModOp.hpp"

// Returns a^x mod m by binary exponentiation, as a value in [0, m).
// Requires m > 0; an exponent x <= 0 yields 1 % m.
long long modpow(long long a, long long x, long long m) {
    // Bring the base into [0, m) so negative or oversized bases give a canonical result.
    a %= m;
    if (a < 0) a += m;
    long long ret = 1 % m;   // 0 when m == 1
    while (x > 0) {
        // Multiply in 128 bits: the product of two values below m overflows
        // 64 bits once m exceeds roughly 3e9.
        if (x & 1) ret = static_cast<long long>(static_cast<__int128>(ret) * a % m);
        a = static_cast<long long>(static_cast<__int128>(a) * a % m);
        x >>= 1;
    }
    return ret;
}

// Modular inverse of a modulo m via the extended Euclidean algorithm, returned
// in [0, m). The result is an actual inverse only when gcd(a, m) == 1.
long long modinv(long long a, long long m) {
    long long x, _;
    extGCD(a, m, x, _);   // a*x + m*_ = gcd(a, m)
    x %= m;
    if (x < 0) x += m;    // normalise a negative Bezout coefficient
    return x;
}

#line 7 "/home/kokoro601/compro_library/DataStructure/NTT.hpp"

template<uint32_t MD, uint32_t root> class NTT {
    using Mint = ModInt<MD>;

    // type : ift or not
    // a.size() should be less than 1 << 23
    void nft(bool type, std::vector<Mint>& a) const {
        int n = (int)a.size(), s = 0;
        while ((1 << s) < n) s++;
        assert(1 << s == n);

        // these are calculated only once because the type is static
        static std::vector<Mint> ep, iep;
        Mint g = root;
        while (ep.size() <= s) {
            ep.push_back(g.pow(Mint(-1).v / (1 << ep.size())));

        std::vector<Mint> b(n);
        // Stockham FFT
        // no need to perform bit reversal (but not in-place)
        // memory access is sequantial
        for (int i = 1; i <= s; i++) { 
            int w = 1 << (s - i);
            Mint base = type ? iep[i] : ep[i];
            Mint now = 1;
            for (int y = 0; y < n / 2; y += w) {
                for (int x = 0; x < w; x++) {
                    auto s = a[y << 1 | x];
                    auto t = now * a[y << 1 | x | w];
                    b[y | x] = s + t;
                    b[y | x | n >> 1] = s - t;
                now *= base;
            swap(a, b);

    std::vector<Mint> convolution(const std::vector<Mint> &a, const std::vector<Mint> &b) const {
        int n = (int)a.size(), m = (int)b.size();
        if (!n || !m) return {};
        int lg = 0;
        while ((1 << lg) < n + m - 1) lg++;
        int z = 1 << lg;
        // The type of a2 and b2 is std::vector<Mint> (not reference)
        // Therefore, a2 and b2 are copies of a and b, respectively.
        auto a2 = a, b2 = b;
        nft(false, a2);
        nft(false, b2);
        for (int i = 0; i < z; i++) a2[i] *= b2[i];
        nft(true, a2);
        a2.resize(n + m - 1);
        Mint iz = Mint(z).inv();
        for (int i = 0; i < n + m - 1; i++) a2[i] *= iz;
        return a2;

template<int MOD> class ArbitraryModNTT {
    using Mint = ModInt<MOD>;
    ArbitraryModNTT() {}

    std::vector<Mint> convolution(std::vector<Mint> &a, std::vector<Mint> &b) const {
        int n = a.size();
        int m = b.size();
        constexpr size_t MOD1 = 167772161;
        constexpr size_t MOD2 = 469762049;
        constexpr size_t MOD3 = 1224736769;
        using Mint1 = ModInt<MOD1>;
        using Mint2 = ModInt<MOD2>;
        using Mint3 = ModInt<MOD3>;

        NTT<MOD1, 3> ntt1;
        NTT<MOD2, 3> ntt2;
        NTT<MOD3, 3> ntt3;

        std::vector<Mint1> a1(n), b1(m);
        std::vector<Mint2> a2(n), b2(m);
        std::vector<Mint3> a3(n), b3(m);

        for (int i = 0; i < n; i++) a1[i] = (a[i].v) % MOD1, a2[i] = (a[i].v) % MOD2, a3[i] = (a[i].v) % MOD3;
        for (int i = 0; i < m; i++) b1[i] = (b[i].v) % MOD1, b2[i] = (b[i].v) % MOD2, b3[i] = (b[i].v) % MOD3;
        auto c1 = ntt1.convolution(a1, b1);
        auto c2 = ntt2.convolution(a2, b2);
        auto c3 = ntt3.convolution(a3, b3);

        std::vector<Mint> c(c1.size());
        const Mint2 m1_inv_m2 = modinv(MOD1, MOD2);
        const Mint3 m1m2_inv_m3 = modinv((Mint3(MOD1) * MOD2).v, MOD3);
        for (int i = 0; i < c1.size(); i++) {
            long long t1 = (m1_inv_m2 * ((long long)c2[i].v - c1[i].v)).v;
            long long t = (m1m2_inv_m3 * ((long long)c3[i].v - t1 * MOD1 - c1[i].v)).v;
            c[i] = Mint(t) * MOD1 * MOD2 + Mint(t1) * MOD1 + c1[i].v;

        return c;
#line 23 "main.cpp"
// File-wide conveniences for the solution code below. main() relies on the
// using-directive for the unqualified names `deque` and `vector`.
using namespace std;
using ll = long long;
// Adjacency-list graph type.
using Graph = vector<vector<int>>;
// Unsigned 128-bit (compiler extension) and 64-bit integer shorthands.
using u128 = __uint128_t;
using u64 = uint64_t;

// Assumes the sentinel entry is already included in the input vectors.
// ll garner(vector<ll> &b, vector<ll> &m) {
//     vector<ll> coeffs(m.size(), 1);
//     vector<ll> constants(m.size(), 0);
//     for (size_t k = 0; k < b.size(); k++) {
//         ll diff = (b[k] > constants[k]) ? (b[k] - constants[k]) : (b[k] - constants[k] + m[k]);
//         ll t = (diff * modinv(coeffs[k], m[k])) % m[k];
//         for (size_t i = k + 1; i < m.size(); i++) {
//             (constants[i] += t * coeffs[i]) %= m[i];
//             (coeffs[i] *= m[k]) %= m[i];
//         }
//     }

//     return constants.back();
// }

// Sorts the sz elements starting at `beg` into ascending order.
// `beg` must be a random-access iterator or a pointer.
//
// The earlier hand-written partition advanced its second write cursor past the
// end of the range and recursed without shrinking the range whenever every
// element was <= the pivot; delegating to std::sort removes both problems.
template<class T> void srt(T beg, size_t sz) {
    if (sz < 2) return;   // nothing to reorder
    std::sort(beg, beg + sz);
}

// Modulus used for all answers; passed below as the template argument of both
// ModInt and NTT (with root 3).
constexpr size_t MOD = 998244353;
// Modular integer type used by main().
using Mint = ModInt<MOD>;
int main() {
    int n, q; cin >> n >> q;
    deque<vector<Mint>> que;
    for (int i = 0; i < n; i++) {
        ll a; cin >> a;
        que.push_back({a-1, 1});

    NTT<MOD, 3> ntt;
    while (que.size() > 1) {
        auto v1 = que.front();
        auto v2 = que.front();
        auto v = ntt.convolution(v1, v2);

    auto v = que.front();
    for (int i = 0; i < q; i++) {
        int b; cin >> b;
        cout << v[b].v << "\n";
    return 0;