結果

問題 No.1068 #いろいろな色 / Red and Blue and more various colors (Hard)
ユーザー KoDKoD
提出日時 2020-05-29 22:09:36
言語 C++17 (gcc 12.3.0 + boost 1.83.0)
結果 AC
実行時間 416 ms / 3,500 ms
コード長 17,078 bytes
コンパイル時間 2,617 ms
コンパイル使用メモリ 210,324 KB
実行使用メモリ 53,768 KB
最終ジャッジ日時 2024-04-23 22:46:32
合計ジャッジ時間 11,093 ms
ジャッジサーバーID (参考情報): judge5 / judge2
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 1 ms 6,816 KB
testcase_01 AC 2 ms 6,944 KB
testcase_02 AC 1 ms 6,944 KB
testcase_03 AC 9 ms 6,940 KB
testcase_04 AC 6 ms 6,940 KB
testcase_05 AC 6 ms 6,940 KB
testcase_06 AC 5 ms 6,940 KB
testcase_07 AC 6 ms 6,944 KB
testcase_08 AC 6 ms 6,944 KB
testcase_09 AC 6 ms 6,944 KB
testcase_10 AC 3 ms 6,940 KB
testcase_11 AC 4 ms 6,944 KB
testcase_12 AC 3 ms 6,944 KB
testcase_13 AC 393 ms 53,636 KB
testcase_14 AC 406 ms 53,636 KB
testcase_15 AC 416 ms 53,512 KB
testcase_16 AC 405 ms 53,636 KB
testcase_17 AC 406 ms 53,508 KB
testcase_18 AC 397 ms 53,460 KB
testcase_19 AC 402 ms 53,516 KB
testcase_20 AC 398 ms 53,512 KB
testcase_21 AC 395 ms 53,768 KB
testcase_22 AC 411 ms 53,764 KB
testcase_23 AC 404 ms 53,640 KB
testcase_24 AC 401 ms 53,636 KB
testcase_25 AC 394 ms 53,644 KB
testcase_26 AC 402 ms 53,636 KB
testcase_27 AC 407 ms 53,624 KB
testcase_28 AC 394 ms 53,516 KB
testcase_29 AC 401 ms 53,640 KB
testcase_30 AC 398 ms 53,640 KB
testcase_31 AC 1 ms 6,940 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;

/**
 * @brief 高速入出力
 * @author えびちゃん
 * @see https://qiita.com/rsk0315_h4x/items/17a9cb12e0de5fd918f4
 */
namespace fast {
// ---- I/O buffers ----------------------------------------------------------
// Each buffer has `margin` spare bytes past `buf_size`, so a NUL sentinel can
// always be stored right after the last valid byte.
static constexpr size_t buf_size = 131072;  // 2^17
static constexpr size_t margin = 1;
static char inbuf[buf_size + margin] = {};
alignas(8) static char outbuf[buf_size + margin] = {};
alignas(8) static char minibuf[32];

// Longest unsigned 64-bit decimal: 18446744073709551615.
static constexpr size_t int_digits = 20;

// ---- SWAR masks used by scanner::scan_parallel ----------------------------
// ASCII '0' (0x30) in every byte.
static constexpr uintmax_t digit_mask = 0x3030'3030'3030'3030;
// Low half of every 16-bit, 32-bit and 64-bit lane, respectively.
static constexpr uintmax_t first_mask = 0x00FF'00FF'00FF'00FF;
static constexpr uintmax_t second_mask = 0x0000'FFFF'0000'FFFF;
static constexpr uintmax_t third_mask = 0x0000'0000'FFFF'FFFF;

// tenpow[i] == 10^i for i in [0, 19].
static constexpr uintmax_t tenpow[] = {
    1,
    10,
    100,
    1'000,
    10'000,
    100'000,
    1'000'000,
    10'000'000,
    100'000'000,
    1'000'000'000,
    10'000'000'000,
    100'000'000'000,
    1'000'000'000'000,
    10'000'000'000'000,
    100'000'000'000'000,
    1'000'000'000'000'000,
    10'000'000'000'000'000,
    100'000'000'000'000'000,
    1'000'000'000'000'000'000,
    10'000'000'000'000'000'000u,
};

// 10000 entries of 4 ASCII bytes each; filled by printer::M_precompute.
alignas(8) static char inttab[40000] = {};

// Separator between values and terminator after println.
static char S_sep = ' ';
static char S_end = '\n';

// SFINAE helper: names `Tp` as `::type` only when Tp is an integral type.
template <typename Tp>
using enable_if_integral = std::enable_if<std::is_integral<Tp>::value, Tp>;

/**
 * Buffered reader for whitespace-separated tokens on stdin.
 *
 * Input is read in chunks of up to `buf_size` bytes into `inbuf`. The byte at
 * `endpos` is always NUL, and the parsers use it as an end-of-data sentinel.
 */
class scanner {
  char *pos = inbuf;                // next unread byte
  char *endpos = inbuf + buf_size;  // one past the last byte read from stdin

  // Initial fill. `inbuf` is zero-initialized and one byte longer than
  // buf_size, so the byte at `endpos` is already NUL after this read.
  void M_read_from_stdin() { endpos = inbuf + fread(pos, 1, buf_size, stdin); }

  // Moves the unread tail [pos, endpos) to the front of inbuf, refills the
  // rest from stdin and re-establishes the NUL sentinel at `endpos`.
  // Returns without doing anything when the tail would overlap its
  // destination, because memcpy must not be used on overlapping ranges.
  void M_reread_from_stdin() {
    ptrdiff_t len = endpos - pos;
    if (!(inbuf + len <= pos)) return;
    memcpy(inbuf, pos, len);
    char *tmp = inbuf + len;
    endpos = tmp + fread(tmp, 1, buf_size - len, stdin);
    *endpos = 0;
    pos = inbuf;
  }

 public:
  // Loads the first chunk of stdin.
  scanner() { M_read_from_stdin(); }

  /**
   * Parses a decimal integer (optionally '-' for signed types) 8 bytes at a
   * time, then skips the one separator byte that follows it.
   *
   * Assumes a 64-bit `long`. Assumes the number is terminated by a byte whose
   * 0x10 bit is clear, such as ' ', '\t', '\n', '\r' or NUL.
   */
  template <typename Integral,
            typename enable_if_integral<Integral>::type * = nullptr>
  void scan_parallel(Integral &x) {
    static_assert(sizeof(long) == 8, "scan_parallel requires a 64-bit long");
    if (__builtin_expect(endpos <= pos + int_digits, 0)) M_reread_from_stdin();
    bool ends = false;
    typename std::make_unsigned<Integral>::type y = 0;
    bool neg = false;
    if (std::is_signed<Integral>::value && *pos == '-') {
      neg = true;
      ++pos;
    }
    do {
      // Load 8 input bytes. memcpy avoids the strict-aliasing violation of
      // reading a char buffer through a long*.
      long c;
      memcpy(&c, pos, sizeof c);
      // Zero in bytes 0x30-0x3F, nonzero in every other byte.
      long d = (c & digit_mask) ^ digit_mask;
      int skip = 8;   // bytes consumed this round
      int shift = 8;  // decimal digits contributed this round
      if (d) {
        // For a terminator with bit 0x10 clear, ctz == 8 * k + 4, where k is
        // the number of digit bytes in front of it.
        int ctz = __builtin_ctzl(d);
        if (ctz == 4) break;  // terminator is the very first byte
        c &= (1L << (ctz - 5)) - 1;      // keep only the k digit bytes
        int discarded = (68 - ctz) / 8;  // == 8 - k
        shift -= discarded;
        // Move the digits into the top bytes. The now-empty low bytes act as
        // leading zeros, because the first (most significant) digit sits in
        // the lowest byte.
        c <<= discarded * 8;
        skip -= discarded;
        ends = true;
      }
      // Strip the ASCII '0' bias, then merge neighbouring lanes:
      // 8 x 1 digit -> 4 x 2 digits -> 2 x 4 digits -> 1 x 8 digits.
      c |= digit_mask;
      c ^= digit_mask;
      c = ((c >> 8) + c * 10) & first_mask;
      c = ((c >> 16) + c * 100) & second_mask;
      c = ((c >> 32) + c * 10000) & third_mask;
      y = y * tenpow[shift] + c;
      pos += skip;
    } while (!ends);
    x = (neg ? -y : y);
    ++pos;  // skip the terminator
  }

  /**
   * Parses a decimal integer (optionally '-' for signed types) one byte at a
   * time, then skips the one separator byte that follows it. Every byte
   * >= '0' is taken as a digit, so the number must be followed by a byte
   * below '0' (whitespace or NUL).
   */
  template <typename Integral,
            typename enable_if_integral<Integral>::type * = nullptr>
  void scan_serial(Integral &x) {
    if (__builtin_expect(endpos <= pos + int_digits, 0)) M_reread_from_stdin();
    bool neg = false;
    if (std::is_signed<Integral>::value && *pos == '-') {
      neg = true;
      ++pos;
    }
    typename std::make_unsigned<Integral>::type y = *pos - '0';
    while (*++pos >= '0') y = 10 * y + *pos - '0';
    x = (neg ? -y : y);
    ++pos;
  }

  // Default integer reader; delegates to scan_parallel. scan_parallel pays
  // off for large values (about 10^12). For values up to about 10^9,
  // scan_serial(x) may be faster.
  template <typename Integral,
            typename enable_if_integral<Integral>::type * = nullptr>
  void scan(Integral &x) {
    scan_parallel(x);
  }

  /**
   * Reads bytes into `s` until the first byte <= ' ', then skips that byte.
   * A token may span a buffer refill. If stdin is exhausted in the middle of
   * a token, the token read so far is returned.
   */
  void scan_serial(std::string &s) {
    s.clear();
    for (;;) {
      char *const startpos = pos;
      while (*pos > ' ') ++pos;
      s.append(startpos, pos);
      if (pos != endpos) {
        ++pos;  // skip the separator
        return;
      }
      // Reached the NUL sentinel at the end of the buffered data: refill and
      // continue the same token. If the refill brought nothing (EOF), stop
      // instead of refilling forever.
      M_reread_from_stdin();
      if (pos == endpos) return;
    }
  }

  void scan(std::string &s) { scan_serial(s); }

  // Reads several values in order.
  template <typename Tp, typename... Args>
  void scan(Tp &x, Args &&... xs) {
    scan(x);
    scan(std::forward<Args>(xs)...);
  }
};

/**
 * Buffered writer to stdout.
 *
 * Output accumulates in `outbuf`. It is passed to fwrite when a write would
 * not fit, and when the printer is destroyed.
 */
class printer {
  char *pos = outbuf;  // next free byte in outbuf

  // Writes everything buffered so far to stdout and empties the buffer.
  void M_flush_stdout() {
    fwrite(outbuf, 1, pos - outbuf, stdout);
    pos = outbuf;
  }

  // Appends n bytes starting at s, flushing whenever outbuf is full.
  void M_write(char const *s, size_t n) {
    while (n > 0) {
      size_t room = outbuf + buf_size - pos;
      if (room == 0) {
        M_flush_stdout();
        continue;
      }
      size_t k = n < room ? n : room;
      memcpy(pos, s, k);
      pos += k;
      s += k;
      n -= k;
    }
  }

  // Number of decimal digits of n (1 for n == 0), found by binary search over
  // tenpow. Handles the full uint64 range, up to 20 digits. The trailing
  // "// d" comments give the comparison depth of each branch.
  static int S_int_digits(uintmax_t n) {
    if (n < tenpow[16]) {                 // 1
      if (n < tenpow[8]) {                // 2
        if (n < tenpow[4]) {              // 3
          if (n < tenpow[2]) {            // 4
            if (n < tenpow[1]) return 1;  // 5
            return 2;                     // 5
          }
          if (n < tenpow[3]) return 3;  // 4
          return 4;                     // 4
        }
        if (n < tenpow[6]) {            // 4
          if (n < tenpow[5]) return 5;  // 5
          return 6;                     // 5
        }
        if (n < tenpow[7]) return 7;  // 5
        return 8;                     // 5
      }
      if (n < tenpow[12]) {             // 3
        if (n < tenpow[10]) {           // 4
          if (n < tenpow[9]) return 9;  // 5
          return 10;                    // 5
        }
        if (n < tenpow[11]) return 11;  // 5
        return 12;                      // 5
      }
      if (n < tenpow[14]) {             // 4
        if (n < tenpow[13]) return 13;  // 5
        return 14;                      // 5
      }
      if (n < tenpow[15]) return 15;  // 5
      return 16;                      // 5
    }
    if (n < tenpow[18]) {             // 2
      if (n < tenpow[17]) return 17;  // 3
      return 18;                      // 3
    }
    // Unsigned 64-bit values >= 10^19 have 20 digits.
    if (n < tenpow[19]) return 19;  // 3
    return 20;                      // 3
  }

  // Fills inttab so that bytes [4*i, 4*i + 4) hold the zero-padded 4-digit
  // ASCII form of i, for i in [0, 9999]. Each 8-byte store writes two
  // consecutive entries. The `digitN` constants are per-byte increments that
  // carry between digit positions in little-endian byte order.
  void M_precompute() {
    unsigned long const digit1 = 0x0200000002000000;
    unsigned long const digit2 = 0xf600fffff6010000;
    unsigned long const digit3 = 0xfff600fffff60100;
    unsigned long const digit4 = 0xfffff600fffff601;
    unsigned long counter = 0x3130303030303030;
    for (int i = 0, i4 = 0; i4 < 10; ++i4, counter += digit4)
      for (int i3 = 0; i3 < 10; ++i3, counter += digit3)
        for (int i2 = 0; i2 < 10; ++i2, counter += digit2)
          for (int i1 = 0; i1 < 5; ++i1, ++i, counter += digit1)
            // memcpy instead of a store through unsigned long*, which would
            // violate strict aliasing on a char array.
            memcpy(inttab + 8 * i, &counter, sizeof counter);
  }

 public:
  printer() { M_precompute(); }
  ~printer() { M_flush_stdout(); }

  void print(char c) {
    if (__builtin_expect(pos + 1 >= outbuf + buf_size, 0)) M_flush_stdout();
    *pos++ = c;
  }

  // Writes a character array of N bytes, excluding its final (NUL) byte.
  template <size_t N>
  void print(char const (&s)[N]) {
    if (__builtin_expect(pos + N >= outbuf + buf_size, 0)) M_flush_stdout();
    memcpy(pos, s, N - 1);
    pos += N - 1;
  }

  // Writes a NUL-terminated string (without the terminator).
  void print(char const *s) { M_write(s, strlen(s)); }

  // Writes all s.size() bytes, including any embedded NUL bytes.
  void print(std::string const &s) { M_write(s.data(), s.size()); }

  /**
   * Writes x in decimal. int_digits bytes are reserved up front. That covers
   * a sign plus at most 19 digits, or at most 20 digits for unsigned values.
   */
  template <typename Integral,
            typename enable_if_integral<Integral>::type * = nullptr>
  void print(Integral x) {
    if (__builtin_expect(pos + int_digits >= outbuf + buf_size, 0))
      M_flush_stdout();
    if (x == 0) {
      *pos++ = '0';
      return;
    }
    if (x < 0) {
      *pos++ = '-';
      // -x would overflow for the minimum value, so write its magnitude
      // literally.
      if (__builtin_expect(x == std::numeric_limits<Integral>::min(), 0)) {
        switch (sizeof x) {
          case 2:
            print("32768");
            return;
          case 4:
            print("2147483648");
            return;
          case 8:
            print("9223372036854775808");
            return;
        }
      }
      x = -x;
    }

    typename std::make_unsigned<Integral>::type y = x;
    int len = S_int_digits(y);
    pos += len;
    char *tmp = pos;
    // Digits are produced right to left. Emit single digits until tmp sits at
    // a multiple of 4 from outbuf, then copy 4-digit groups from inttab.
    int w = (pos - outbuf) & 3;
    if (w > len) w = len;
    for (int i = w; i > 0; --i) {
      *--tmp = y % 10 + '0';
      y /= 10;
    }
    len -= w;
    while (len >= 4) {
      tmp -= 4;
      memcpy(tmp, inttab + 4 * (y % 10000), 4);
      len -= 4;
      if (len) y /= 10000;
    }
    while (len-- > 0) {
      *--tmp = y % 10 + '0';
      y /= 10;
    }
  }

  // Writes x followed by S_sep and the remaining values.
  template <typename Tp, typename... Args>
  void print(Tp const &x, Args &&... xs) {
    if (sizeof...(Args) > 0) {
      print(x);
      print(S_sep);
      print(std::forward<Args>(xs)...);
    }
  }

  template <typename Tp>
  void println(Tp const &x) {
    print(x), print(S_end);
  }

  template <typename Tp, typename... Args>
  void println(Tp const &x, Args &&... xs) {
    print(x, std::forward<Args>(xs)...);
    print(S_end);
  }

  static void set_sep(char c) { S_sep = c; }
  static void set_end(char c) { S_end = c; }
};
}  // namespace fast

// Program-wide fast I/O objects. fastin reads the first chunk of stdin in its
// constructor. fastout writes its buffered output to stdout in its destructor.
fast::scanner fastin{};
fast::printer fastout{};

/**
 * Returns the inverse of `mod` modulo 2^32, i.e. r with r * mod == 1 (mod 2^32).
 *
 * Uses Newton's iteration x <- x * (2 - mod * x). For odd mod, x = mod is
 * already correct in its low 3 bits (mod^2 == 1 mod 8). Each step doubles
 * the number of correct low bits, so 4 steps give 48 >= 32 bits. This works
 * for every odd modulus. The previous repeated-squaring loop only reached
 * the inverse when mod >= 2^29.
 *
 * @param mod odd modulus (no inverse exists for even mod).
 */
static constexpr uint32_t get_r(int mod) {
  const uint32_t m = static_cast<uint32_t>(mod);
  uint32_t ret = m;
  for (int i = 0; i < 4; ++i) ret *= 2u - m * ret;
  return ret;
}

/**
 * Integer modulo an odd `mod` < 2^30, stored in Montgomery form with
 * R = 2^32.
 *
 * The stored value `a` is reduced lazily: the arithmetic operators keep it in
 * [0, 2 * mod), and get() normalizes the result to [0, mod).
 */
template <uint32_t mod>
struct LazyMontgomeryModInt {
  using mint = LazyMontgomeryModInt;
  using i32 = int32_t;
  using u32 = uint32_t;
  using u64 = uint64_t;

  // r = mod^{-1} modulo 2^32.
  static constexpr u32 r = get_r(mod);
  // n2 = 2^64 mod mod. reduce(x * n2) converts x into Montgomery form.
  static constexpr u32 n2 = -u64(mod) % mod;
  static_assert(r * mod == 1, "invalid, r * mod != 1");
  static_assert(mod < (1 << 30), "invalid, mod >= 2 ^ 30");
  static_assert((mod & 1) == 1, "invalid, mod % 2 == 0");

  u32 a;  // Montgomery representation of the value

  LazyMontgomeryModInt() : a(0) {}
  // Converts b of any sign into Montgomery form.
  LazyMontgomeryModInt(const int64_t &b) : a(reduce(u64(b % mod + mod) * n2)) {}

  // Montgomery reduction: returns a value congruent to b * 2^-32 (mod mod).
  // For b < mod * 2^32 the result lies in [1, 2 * mod).
  static u32 reduce(const u64 &b) {
    const u32 q = u32(b) * r;  // q * mod == b (mod 2^32)
    const u32 hi = u32((u64(q) * mod) >> 32);
    return u32(b >> 32) + mod - hi;
  }

  mint &operator+=(const mint &b) {
    a += b.a - 2 * mod;
    if (i32(a) < 0) a += 2 * mod;
    return *this;
  }

  mint &operator-=(const mint &b) {
    a -= b.a;
    if (i32(a) < 0) a += 2 * mod;
    return *this;
  }

  mint &operator*=(const mint &b) {
    a = reduce(u64(a) * b.a);
    return *this;
  }

  mint &operator/=(const mint &b) { return *this *= b.inverse(); }

  mint operator+(const mint &b) const {
    mint res(*this);
    res += b;
    return res;
  }
  mint operator-(const mint &b) const {
    mint res(*this);
    res -= b;
    return res;
  }
  mint operator*(const mint &b) const {
    mint res(*this);
    res *= b;
    return res;
  }
  mint operator/(const mint &b) const {
    mint res(*this);
    res /= b;
    return res;
  }

  // Plain value in [0, mod).
  u32 get() const {
    const u32 v = reduce(a);
    return v < mod ? v : v - mod;
  }

  // this^n by binary exponentiation.
  mint pow(u64 n) const {
    mint result(1), base(*this);
    for (; n > 0; n >>= 1, base *= base)
      if (n & 1) result *= base;
    return result;
  }

  friend ostream &operator<<(ostream &os, const mint &b) {
    return os << b.get();
  }

  friend istream &operator>>(istream &is, mint &b) {
    int64_t t;
    is >> t;
    b = mint(t);
    return is;
  }

  // Multiplicative inverse via Fermat's little theorem (mod assumed prime).
  mint inverse() const { return pow(mod - 2); }

  static constexpr u32 get_mod() { return mod; }
};

/**
 * Returns the smallest g >= 2 such that g^((mod - 1) / q) != 1 (mod mod) for
 * every prime q dividing mod - 1. For a prime mod, this is the smallest
 * primitive root.
 */
static constexpr uint32_t get_pr(uint32_t mod) {
  // Distinct prime factors of mod - 1, by trial division.
  uint64_t primes[32] = {};
  int cnt = 0;
  uint64_t rest = mod - 1;
  for (uint64_t p = 2; p * p <= rest; ++p) {
    if (rest % p != 0) continue;
    primes[cnt++] = p;
    do rest /= p;
    while (rest % p == 0);
  }
  if (rest != 1) primes[cnt++] = rest;

  // Try candidates in increasing order until one passes every prime test.
  for (uint32_t g = 2;; ++g) {
    bool is_generator = true;
    for (int j = 0; j < cnt && is_generator; ++j) {
      uint64_t base = g, e = (mod - 1) / primes[j], acc = 1;
      for (; e; e >>= 1, base = base * base % mod)
        if (e & 1) acc = acc * base % mod;
      is_generator = (acc != 1);
    }
    if (is_generator) return g;
  }
}

/**
 * Number-theoretic transform over the prime mint::get_mod(), used for
 * polynomial multiplication.
 *
 * `level` is the number of trailing zero bits of mod - 1, so transform
 * lengths up to 2^level are supported. dw / dy hold the step factors used to
 * advance the twiddle factor in the forward / inverse radix-4 passes. They
 * are indexed by the trailing-zero count of the block index and rebuilt by
 * setwy() on every multiply().
 */
template <typename mint>
struct NTT {
  static constexpr uint32_t mod = mint::get_mod();
  static constexpr uint32_t pr = get_pr(mod);
  static constexpr int level = __builtin_ctzll(mod - 1);
  // One spare slot: after its final iteration, each butterfly loop advances
  // the twiddle factor once more and reads dw[k] / dy[k]. That value is never
  // used, but with k == level the index would be past the end of an array of
  // size `level`.
  mint dw[level + 1], dy[level + 1];

  // Builds dw / dy for a transform of length 2^k (2 <= k <= level).
  // w[i] is a primitive 2^(i+1)-th root of unity and y[i] is its inverse.
  void setwy(int k) {
    mint w[level], y[level];
    w[k - 1] = mint(pr).pow((mod - 1) / (1 << k));
    y[k - 1] = w[k - 1].inverse();
    for (int i = k - 2; i > 0; --i)
      w[i] = w[i + 1] * w[i + 1], y[i] = y[i + 1] * y[i + 1];
    dw[1] = w[1], dy[1] = y[1], dw[2] = w[2], dy[2] = y[2];
    for (int i = 3; i < k; ++i) {
      dw[i] = dw[i - 1] * y[i - 2] * w[i];
      dy[i] = dy[i - 1] * w[i - 2] * y[i];
    }
  }

  // In-place forward transform of `a` (length 2^k). When k is odd, one
  // radix-2 pass runs first, followed by radix-4 passes. No bit-reversal
  // permutation is applied. multiply() relies on ifft4 undoing this
  // transform up to a factor of 2^k.
  void fft4(vector<mint> &a, int k) {
    if (k & 1) {
      int v = 1 << (k - 1);
      for (int j = 0; j < v; ++j) {
        mint ajv = a[j + v];
        a[j + v] = a[j] - ajv;
        a[j] += ajv;
      }
    }
    int u = 1 << (2 + (k & 1));
    int v = 1 << (k - 2 - (k & 1));
    mint one = mint(1);
    mint imag = dw[1];
    while (v) {
      // jh = 0
      {
        int j0 = 0;
        int j1 = v;
        int j2 = j1 + v;
        int j3 = j2 + v;
        for (; j0 < v; ++j0, ++j1, ++j2, ++j3) {
          mint t0 = a[j0], t1 = a[j1], t2 = a[j2], t3 = a[j3];
          mint t0p2 = t0 + t2, t1p3 = t1 + t3;
          mint t0m2 = t0 - t2, t1m3 = (t1 - t3) * imag;
          a[j0] = t0p2 + t1p3, a[j1] = t0p2 - t1p3;
          a[j2] = t0m2 + t1m3, a[j3] = t0m2 - t1m3;
        }
      }
      // jh >= 1
      mint ww = one, xx = one * dw[2], wx = one;
      for (int jh = 4; jh < u;) {
        ww = xx * xx, wx = ww * xx;
        int j0 = jh * v;
        int je = j0 + v;
        int j2 = je + v;
        for (; j0 < je; ++j0, ++j2) {
          mint t0 = a[j0], t1 = a[j0 + v] * xx, t2 = a[j2] * ww,
               t3 = a[j2 + v] * wx;
          mint t0p2 = t0 + t2, t1p3 = t1 + t3;
          mint t0m2 = t0 - t2, t1m3 = (t1 - t3) * imag;
          a[j0] = t0p2 + t1p3, a[j0 + v] = t0p2 - t1p3;
          a[j2] = t0m2 + t1m3, a[j2 + v] = t0m2 - t1m3;
        }
        // On the last iteration jh reaches u, so this may read dw[k].
        xx *= dw[__builtin_ctzll((jh += 4))];
      }
      u <<= 2;
      v >>= 2;
    }
  }

  // In-place inverse of fft4 without the final division by 2^k. It runs the
  // radix-4 passes in reverse order, then a trailing radix-2 pass when k is
  // odd.
  void ifft4(vector<mint> &a, int k) {
    int u = 1 << (k - 2);
    int v = 1;
    mint one = mint(1);
    mint imag = dy[1];
    while (u) {
      // jh = 0
      {
        int j0 = 0;
        int j1 = v;
        int j2 = v + v;
        int j3 = j2 + v;
        for (; j0 < v; ++j0, ++j1, ++j2, ++j3) {
          mint t0 = a[j0], t1 = a[j1], t2 = a[j2], t3 = a[j3];
          mint t0p1 = t0 + t1, t2p3 = t2 + t3;
          mint t0m1 = t0 - t1, t2m3 = (t2 - t3) * imag;
          a[j0] = t0p1 + t2p3, a[j2] = t0p1 - t2p3;
          a[j1] = t0m1 + t2m3, a[j3] = t0m1 - t2m3;
        }
      }
      // jh >= 1
      mint ww = one, xx = one * dy[2], yy = one;
      u <<= 2;
      for (int jh = 4; jh < u;) {
        ww = xx * xx, yy = xx * imag;
        int j0 = jh * v;
        int je = j0 + v;
        int j2 = je + v;
        for (; j0 < je; ++j0, ++j2) {
          mint t0 = a[j0], t1 = a[j0 + v], t2 = a[j2], t3 = a[j2 + v];
          mint t0p1 = t0 + t1, t2p3 = t2 + t3;
          mint t0m1 = (t0 - t1) * xx, t2m3 = (t2 - t3) * yy;
          a[j0] = t0p1 + t2p3, a[j2] = (t0p1 - t2p3) * ww;
          a[j0 + v] = t0m1 + t2m3, a[j2 + v] = (t0m1 - t2m3) * ww;
        }
        // On the last iteration jh reaches u, so this may read dy[k].
        xx *= dy[__builtin_ctzll(jh += 4)];
      }
      u >>= 4;
      v <<= 2;
    }
    if (k & 1) {
      u = 1 << (k - 1);
      for (int j = 0; j < u; ++j) {
        mint ajv = a[j] - a[j + u];
        a[j] += a[j + u];
        a[j + u] = ajv;
      }
    }
  }

  /**
   * Returns the coefficients of the product of polynomials a and b. The
   * result has a.size() + b.size() - 1 entries, or is empty if either input
   * is empty (the zero polynomial).
   */
  vector<mint> multiply(const vector<mint> &a, const vector<mint> &b) {
    // Without this guard, the length below would be negative (both inputs
    // empty) or a zero-filled vector would come back (one input empty).
    if (a.empty() || b.empty()) return {};
    int l = a.size() + b.size() - 1;
    int k = 2, M = 4;
    while (M < l) M <<= 1, ++k;
    assert(k <= level && "product length exceeds the NTT-friendly size");
    setwy(k);
    vector<mint> s(M), t(M);
    for (int i = 0; i < (int)a.size(); ++i) s[i] = a[i];
    for (int i = 0; i < (int)b.size(); ++i) t[i] = b[i];
    fft4(s, k);
    fft4(t, k);
    for (int i = 0; i < M; ++i) s[i] *= t[i];
    ifft4(s, k);
    s.resize(l);
    mint invm = mint(M).inverse();
    for (int i = 0; i < l; ++i) s[i] *= invm;
    return s;
  }
};

/**
 * yukicoder No.1068 "Red and Blue and more various colors (Hard)".
 *
 * Reads N and Q, then N values x_i. It builds the polynomial
 * prod_i ((x_i - 1) + t) in t by multiplying sibling nodes bottom-up in a
 * segment-tree layout, so vec[1] ends up holding the full product. Each
 * query x prints the coefficient of t^x modulo 998244353.
 */
int main() {
  constexpr uint32_t MOD = 998244353;
  using mint = LazyMontgomeryModInt<MOD>;
  NTT<mint> ntt;

  int N, Q;
  fastin.scan_serial(N);
  fastin.scan_serial(Q);

  // Leaves live at vec[size + i]. Each inner node vec[i] becomes the product
  // of its children vec[2i] and vec[2i + 1].
  int size = 1;
  while (size < N) size <<= 1;
  std::vector<std::vector<mint>> vec(size << 1);

  for (int i = 0; i < N; ++i) {
    uint64_t x;
    fastin.scan_parallel(x);
    auto &v = vec[size + i];
    v.resize(2);
    // (x - 1) mod MOD, written so it does not wrap around when x == 0.
    // reduce(value * n2) stores the value in Montgomery form.
    v[0].a = mint::reduce(((x % MOD + MOD - 1) % MOD) * mint::n2);
    v[1].a = mint::reduce(uint64_t(1) * mint::n2);
  }

  for (int i = size - 1; i > 0; --i) {
    auto &l = vec[i << 1 | 0];
    auto &r = vec[i << 1 | 1];
    if (l.empty()) {
      vec[i] = std::move(r);
    } else if (r.empty()) {
      vec[i] = std::move(l);
    } else {
      vec[i] = ntt.multiply(l, r);
      // Children are never read again. Release them so that peak memory
      // stays O(N) instead of O(N log N).
      std::vector<mint>().swap(l);
      std::vector<mint>().swap(r);
    }
  }

  // With no factors at all, the product is the constant polynomial 1.
  if (vec[1].empty()) vec[1] = std::vector<mint>{mint(1)};

  const auto &poly = vec[1];
  while (Q--) {
    int x;
    fastin.scan_serial(x);
    // Coefficients beyond the product's degree are zero.
    const uint32_t ans =
        (0 <= x && x < (int)poly.size()) ? poly[x].get() : 0;
    fastout.println(ans);
  }

  return 0;
}
0