結果

問題 No.206 数の積集合を求めるクエリ
ユーザー Min_25Min_25
提出日時 2017-11-19 18:29:40
言語 C++14
(gcc 13.2.0 + boost 1.83.0)
結果
AC  
実行時間 17 ms / 7,000 ms
コード長 6,925 bytes
コンパイル時間 874 ms
コンパイル使用メモリ 81,716 KB
実行使用メモリ 6,340 KB
最終ジャッジ日時 2023-08-17 05:59:16
合計ジャッジ時間 3,032 ms
ジャッジサーバーID
(参考情報)
judge12 / judge14
このコードへのチャレンジ(β)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
4,376 KB
testcase_01 AC 2 ms
4,380 KB
testcase_02 AC 2 ms
4,380 KB
testcase_03 AC 2 ms
4,380 KB
testcase_04 AC 1 ms
4,376 KB
testcase_05 AC 1 ms
4,376 KB
testcase_06 AC 2 ms
4,376 KB
testcase_07 AC 2 ms
4,376 KB
testcase_08 AC 2 ms
4,380 KB
testcase_09 AC 2 ms
4,376 KB
testcase_10 AC 1 ms
4,380 KB
testcase_11 AC 1 ms
4,376 KB
testcase_12 AC 2 ms
4,376 KB
testcase_13 AC 2 ms
4,380 KB
testcase_14 AC 2 ms
4,376 KB
testcase_15 AC 2 ms
4,380 KB
testcase_16 AC 2 ms
4,376 KB
testcase_17 AC 15 ms
5,844 KB
testcase_18 AC 13 ms
5,780 KB
testcase_19 AC 14 ms
5,840 KB
testcase_20 AC 14 ms
5,740 KB
testcase_21 AC 13 ms
5,840 KB
testcase_22 AC 14 ms
5,840 KB
testcase_23 AC 14 ms
5,852 KB
testcase_24 AC 17 ms
6,228 KB
testcase_25 AC 17 ms
6,136 KB
testcase_26 AC 16 ms
6,156 KB
testcase_27 AC 15 ms
6,092 KB
testcase_28 AC 16 ms
6,144 KB
testcase_29 AC 16 ms
6,196 KB
testcase_30 AC 15 ms
6,340 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cstdio>
#include <cassert>
#include <cmath>
#include <cstring>

#include <algorithm>
#include <iostream>
#include <vector>
#include <functional>

// Overloaded loop macro: rep(i, n), rep(i, a, b) or rep(i, a, b, c).
// _rep returns its 5th argument. Appending the candidate macros after
// __VA_ARGS__ shifts them so the number of user arguments selects the form:
//   2 args -> rep2, 3 args -> rep3, 4 args -> rep4.
#define _rep(_1, _2, _3, _4, name, ...) name
// rep2(i, n): i = 0, 1, ..., n - 1
#define rep2(i, n) rep3(i, 0, n)
// rep3(i, a, b): i = a, a + 1, ..., b - 1
#define rep3(i, a, b) rep4(i, a, b, 1)
// rep4(i, a, b, c): i = a, a + c, ... while i < b; every operand is converted to int.
#define rep4(i, a, b, c) for (int i = int(a); i < int(b); i += int(c))
#define rep(...) _rep(__VA_ARGS__, rep4, rep3, rep2, _)(__VA_ARGS__)

using namespace std;

// Short numeric aliases used throughout the file.
// NOTE(review): the names assume unsigned is 32-bit and long double is the x87
// 80-bit format; both are platform-dependent.
using i64 = long long;
using u32 = unsigned;
using u64 = unsigned long long;
using f80 = long double;

namespace ntt {

// One machine word of residue storage, and the 128-bit type (GCC/Clang
// extension) that holds the full product of two words.
using word_t = unsigned long long;
using dword_t = __uint128_t;

// Returns the inverse of an odd n modulo 2^64, i.e. x such that x * n == 1 in
// word_t arithmetic, using Newton/Hensel lifting.
//
// Each step x <- x * (2 - x * n) doubles the number of correct low bits.
// Starting from the seed x = 1 (already correct in bit 0 for odd n), the
// default e = 6 steps reach 2^6 = 64 correct bits. Callers may pass a smaller
// e or a better seed x.
static constexpr word_t mul_inv(word_t n, int e=6, word_t x=1) {
  while (e != 0) {
    x *= 2 - x * n;
    --e;
  }
  return x;
}

// Montgomery-form arithmetic modulo the compile-time modulus `mod`, with
// R = 2^word_bits = 2^64.
//
// Representation:
// - The member `x` holds (value * R) mod `mod`, but it is NOT kept fully
//   reduced ("unsafe").
// - += and -= add without reducing; *= passes the 128-bit product through
//   reduce().
// - get() returns the canonical value in [0, mod).
//
// NOTE(review): correctness relies on unchecked invariants. Sums must not
// overflow 64 bits, products given to reduce() must stay below mod * 2^64, and
// the right-hand side of -= must not exceed 3 * mod. Keeping `mod` well below
// 2^64 provides the headroom.
template <word_t mod, word_t prim_root>
class UnsafeMod {
private:
  static const int word_bits = 8 * sizeof(word_t);

public:
  // mod^{-1} modulo 2^64 (requires odd mod; verified by the static_assert below).
  static constexpr word_t inv = mul_inv(mod);
  // 2^128 mod `mod`, i.e. R^2 mod `mod`: -dword_t(mod) == 2^128 - mod in 128-bit arithmetic.
  static constexpr word_t r2 = -dword_t(mod) % mod;
  // 2-adic valuation of mod - 1: the largest k with 2^k | mod - 1, bounding NTT length to 2^level.
  static constexpr int level = __builtin_ctzll(mod - 1);
  static_assert(inv * mod == 1, "invalid 1/M modulo 2^@.");

  // Leaves x uninitialized.
  UnsafeMod() {}
  // Implicit conversion from a plain integer into Montgomery form.
  UnsafeMod(word_t n) : x(init(n)) {};
  static word_t modulus() { return mod; }
  // reduce(w * R^2) == w * R (mod `mod`): converts w into Montgomery form.
  static word_t init(word_t w) { return reduce(dword_t(w) * r2); }
  // Montgomery reduction: returns a value congruent to w * R^{-1} modulo `mod`.
  // With q = low64(w) * inv, q * mod has the same low 64 bits as w, so
  // (w - q * mod) / 2^64 == high64(w) - high64(q * mod) exactly; adding `mod`
  // keeps the result non-negative. For w < mod * 2^64 the result lies in [1, 2*mod - 1].
  static word_t reduce(const dword_t w) { 
    return word_t(w >> word_bits) 
         + mod - word_t((dword_t(word_t(w) * inv) * mod) >> word_bits); }
  // prim_root^((mod - 1) / 2^level). If prim_root generates the multiplicative
  // group mod `mod` (assumption, not checked), this is a primitive 2^level-th root of unity.
  static UnsafeMod omega() { return UnsafeMod(prim_root).pow((mod - 1) >> level); }
  UnsafeMod& operator += (UnsafeMod rhs) { x += rhs.x; return *this; }
  // Adds 3 * mod before subtracting so the raw word cannot wrap below zero
  // (assumes rhs.x <= 3 * mod).
  UnsafeMod& operator -= (UnsafeMod rhs) { x += 3 * mod - rhs.x; return *this; }
  UnsafeMod& operator *= (UnsafeMod rhs) { x = reduce(dword_t(x) * rhs.x); return *this; }
  UnsafeMod operator + (UnsafeMod rhs) const { return UnsafeMod(*this) += rhs; }
  UnsafeMod operator - (UnsafeMod rhs) const { return UnsafeMod(*this) -= rhs; }
  UnsafeMod operator * (UnsafeMod rhs) const { return UnsafeMod(*this) *= rhs; }
  // Leaves Montgomery form (multiply by R^{-1}) and normalizes into [0, mod).
  word_t get() const { return reduce(x) % mod; }
  // Stores n directly as the raw Montgomery representation (no conversion).
  void set(word_t n) { x = n; }
  // Binary exponentiation: *this raised to e.
  UnsafeMod pow(word_t e) const {
    UnsafeMod ret = UnsafeMod(1);
    for (UnsafeMod base = *this; e; e >>= 1, base *= base) if (e & 1) ret *= base;
    return ret;
  }
  // Fermat inverse a^(mod - 2); only valid when `mod` is prime and *this != 0.
  UnsafeMod inverse() const { return pow(mod - 2); }
  // Streams the canonical value, as returned by get().
  friend ostream& operator << (ostream& os, const UnsafeMod& m) { return os << m.get(); }
  // Prints mod, inv, r2 and the canonical value of omega() to stdout.
  static void debug() { printf("%llu %llu %llu %llu\n", mod, inv, r2, omega().get()); }
  // Raw (lazily reduced) Montgomery representation.
  word_t x;
};

// In-place forward number-theoretic transform of A[0 .. n).
//
// Requirements:
// - n must be a power of two (logn is taken from __builtin_ctz(n)) and,
//   presumably, at most 2^mod_t::level.
// - roots[i] must be omega^(2^i) and iroots[i] its inverse, which is how
//   convolve() builds them. Under that assumption roots[lv - 2] is a primitive
//   4th root of unity, used here as `imag`.
// - The code reads roots[lv - 3] and dw[lv - 3], so mod_t::level must be >= 3.
//
// Structure:
// - One twiddle-free radix-2 stage over the whole array runs first when
//   log2(n) is odd.
// - Radix-4 stages follow, from the largest block size down to 4.
// - Within a stage, each block of m elements uses a single twiddle w2 (and
//   w1 = w2^2, w3 = w2^3). w2 is advanced between blocks by dw[k], where k is
//   the number of trailing 1 bits of the block index.
//
// NOTE(review): this looks like bit-reversed block ordering, so the output is
// likely not in natural order. convolve() only multiplies pointwise and then
// calls itransform(), so it does not depend on the order.
// For n == 1, logn is 0 and no stage runs.
template <typename mod_t>
void transform(mod_t* A, int n, const mod_t* roots, const mod_t* iroots) {
  const int logn = __builtin_ctz(n), nh = n >> 1, lv = mod_t::level;
  const mod_t one = mod_t(1), imag = roots[lv - 2];

  // Per-block twiddle step factors, indexed by the trailing-ones count of the
  // block index. lv - 1 entries cover the deepest case of n == 2^level.
  mod_t dw[lv - 1]; dw[0] = roots[lv - 3];
  for (int i = 1; i < lv - 2; ++i) dw[i] = dw[i - 1] * iroots[lv - 1 - i] * roots[lv - 3 - i];
  dw[lv - 2] = dw[lv - 3] * iroots[1];

  // Odd log2(n): one radix-2 butterfly stage with stride n / 2, so that the
  // remaining number of levels is even.
  if (logn & 1) for (int i = 0; i < nh; ++i) {
    mod_t a = A[i], b = A[i + nh];
    A[i] = a + b; A[i + nh] = a - b;
  }
  for (int e = logn & ~1; e >= 2; e -= 2) {
    const int m = 1 << e, m4 = m >> 2;
    mod_t w2 = one;
    for (int i = 0; i < n; i += m) {
      const mod_t w1 = w2 * w2, w3 = w1 * w2;
      for (int j = i; j < i + m4; ++j) {
        // Multiplying by `one` keeps the value mod `mod` but passes it through
        // reduce(), bounding the lazily-reduced representation like the other inputs.
        mod_t a0 = A[j + m4 * 0] * one, a1 = A[j + m4 * 1] * w2;
        mod_t a2 = A[j + m4 * 2] * w1,  a3 = A[j + m4 * 3] * w3;
        mod_t t02p = a0 + a2, t13p = a1 + a3;
        mod_t t02m = a0 - a2, t13m = (a1 - a3) * imag;
        A[j + m4 * 0] = t02p + t13p; A[j + m4 * 1] = t02p - t13p;
        A[j + m4 * 2] = t02m + t13m; A[j + m4 * 3] = t02m - t13m;
      }
      // ~(block index) has as many trailing zeros as the index has trailing ones.
      w2 *= dw[__builtin_ctz(~(i >> e))];
    }
  }
}

// In-place inverse of transform() on A[0 .. n), WITHOUT the final 1/n scaling.
// convolve() folds that factor into the pointwise product instead.
//
// This mirrors transform() with the roles of roots and iroots swapped. It
// runs radix-4 stages from block size 4 up, with the twiddles applied to the
// butterfly outputs, and finishes with a twiddle-free radix-2 stage over the
// whole array when log2(n) is odd. It has the same requirements as
// transform(): n a power of two, roots[i] = omega^(2^i), iroots[i] its
// inverse, and mod_t::level >= 3.
template <typename mod_t>
void itransform(mod_t* A, int n, const mod_t* roots, const mod_t* iroots) {
  const int logn = __builtin_ctz(n), nh = n >> 1, lv = mod_t::level;
  const mod_t one = mod_t(1), imag = iroots[lv - 2];

  // Inverse counterparts of transform()'s per-block twiddle step factors.
  mod_t dw[lv - 1]; dw[0] = iroots[lv - 3];
  for (int i = 1; i < lv - 2; ++i) dw[i] = dw[i - 1] * roots[lv - 1 - i] * iroots[lv - 3 - i];
  dw[lv - 2] = dw[lv - 3] * roots[1];

  for (int e = 2; e <= logn; e += 2) {
    const int m = 1 << e, m4 = m >> 2;
    mod_t w2 = one;
    for (int i = 0; i < n; i += m) {
      const mod_t w1 = w2 * w2, w3 = w1 * w2;
      for (int j = i; j < i + m4; ++j) {
        mod_t a0 = A[j + m4 * 0], a1 = A[j + m4 * 1];
        mod_t a2 = A[j + m4 * 2], a3 = A[j + m4 * 3];
        mod_t t01p = a0 + a1, t23p = a2 + a3;
        mod_t t01m = a0 - a1, t23m = (a2 - a3) * imag;
        // `* one` reduces the lazily-accumulated sum without changing its value mod `mod`.
        A[j + m4 * 0] = (t01p + t23p) * one; A[j + m4 * 2] = (t01p - t23p) * w1;
        A[j + m4 * 1] = (t01m + t23m) * w2;  A[j + m4 * 3] = (t01m - t23m) * w3;
      }
      w2 *= dw[__builtin_ctz(~(i >> e))];
    }
  }
  // Odd log2(n): final radix-2 stage with stride n / 2.
  if (logn & 1) for (int i = 0; i < nh; ++i) {
    mod_t a = A[i], b = A[i + nh];
    A[i] = a + b; A[i + nh] = a - b;
  }
}

// Multiplies the polynomials A (s1 coefficients) and B (s2 coefficients) with a
// number-theoretic transform and stores the result in A.
//
// Modes:
// - cyclic == false: A[0 .. s1 + s2 - 1) receives the ordinary (linear) product.
// - cyclic == true: the transform is sized from s = max(s1, s2) instead.
//   NOTE(review): the result is then cyclic modulo the power-of-two transform
//   size, which equals s only when s is a power of two.
//
// Buffer requirements:
// - Both A and B must have room for `size` elements, where size is the
//   smallest power of two >= s. Entries past the inputs are zero-filled, and
//   size must not exceed 2^mod_t::level (asserted).
//
// Side effects and edge cases:
// - B is overwritten with its transform, unless A == B and s1 == s2, which
//   is treated as squaring.
//   NOTE(review): A == B with s1 != s2 is not special-cased; the zero-fill of
//   B's tail would clobber the transformed A.
// - If either length is <= 0 there is no product to compute and A and B are
//   left untouched. Previously this case could make s <= 0, so the shift
//   below overflowed and the fill ranges were invalid.
template <typename mod_t>
void convolve(mod_t* A, int s1, mod_t* B, int s2, bool cyclic=false) {
  if (s1 <= 0 || s2 <= 0) return;
  const int s = cyclic ? std::max(s1, s2) : s1 + s2 - 1;
  // 2^floor(log2(2s - 1)) is the smallest power of two >= s (s >= 1 here).
  const int size = 1 << (31 - __builtin_clz(2 * s - 1));
  assert(size <= (i64(1) << mod_t::level));

  // roots[i] = omega^(2^i) and iroots[i] = omega^(-2^i), with omega = mod_t::omega().
  mod_t roots[mod_t::level], iroots[mod_t::level];
  roots[0] = mod_t::omega();
  for (int i = 1; i < mod_t::level; ++i) roots[i] = roots[i - 1] * roots[i - 1];
  iroots[0] = roots[0].inverse();
  for (int i = 1; i < mod_t::level; ++i) iroots[i] = iroots[i - 1] * iroots[i - 1];

  std::fill(A + s1, A + size, 0); transform(A, size, roots, iroots);
  // itransform() does not divide by size, so fold 1/size into the pointwise product.
  const mod_t inv = mod_t(size).inverse();
  if (A == B && s1 == s2) {
    for (int i = 0; i < size; ++i) A[i] *= A[i] * inv;
  } else {
    std::fill(B + s2, B + size, 0); transform(B, size, roots, iroots);
    for (int i = 0; i < size; ++i) A[i] *= B[i] * inv;
  }
  itransform(A, size, roots, iroots);
}

// Capacity (in slots) of the scratch polynomials f and g used by main.
// NOTE(review): presumably sized for the largest transform this program needs;
// convolve() writes up to the full power-of-two transform size into both.
constexpr int alloc_size = 1 << 17;
// Modulus 1128298388379402241 (<= 1.14e18) with 23 passed as prim_root. The
// "sub.D = 3" note presumably refers to the 3 * mod slack added by
// UnsafeMod::operator-=.
using m64 = ntt::UnsafeMod<1128298388379402241, 23>;
m64 f[alloc_size];
m64 g[alloc_size];

} // namespace ntt

// Route character I/O through the POSIX *_unlocked variants, which skip the
// per-call stream lock.
// NOTE(review): only safe while no other thread uses stdin/stdout; the code
// shown here creates no threads.
#define getchar getchar_unlocked
#define putchar putchar_unlocked

// Reads the next non-negative decimal integer from stdin.
//
// - Any characters that are not '0'..'9' before the number are skipped. A
//   leading '-' is skipped too, so negative values are not supported.
// - The single character that ends the number is consumed.
// - Returns 0 if end of input is reached before any digit.
//
// `c` is an int so that EOF stays distinguishable from every byte value.
// Previously it was a char: on signed-char targets EOF made the skip loop spin
// forever, and non-digits such as letters were folded into the number.
int get_int() {
  int c = getchar();
  while (c != EOF && (c < '0' || c > '9')) c = getchar();
  if (c == EOF) return 0;
  int n = 0;
  for (; c >= '0' && c <= '9'; c = getchar()) n = n * 10 + (c - '0');
  return n;
}

// Writes n to stdout in decimal, followed by a newline, one character at a
// time via putchar.
void put_u32(u32 n) {
  char buf[11];  // 4294967295 has 10 digits; one slot to spare
  int len = 0;
  // Collect digits least-significant first; the loop runs at least once so 0 prints "0".
  for (;;) {
    buf[len++] = static_cast<char>('0' + n % 10);
    n /= 10;
    if (n == 0) break;
  }
  while (len > 0) putchar(buf[--len]);
  putchar('\n');
}

// Input and output:
// - Reads L, M, N, then L values a, then M values b, then Q.
// - For v = 0 .. Q-1, prints the number of pairs (a, b) with a - b == v, one
//   per line. This is presumably |A ∩ (B + v)| for this problem; verify
//   against the statement.
//
// Method:
// - a is placed at index a of polynomial f and b at index N - b of g, so the
//   coefficient c[N + v] of f * g counts the pairs with a - b == v.
// - To halve the NTT length, indices 2k and 2k+1 share slot k, with the odd
//   index weighted by 2^17.
// - In the product, slot k equals E + O * 2^17 + D * 2^34, where:
//   - E counts even+even index pairs summing to 2k;
//   - O = c[2k+1];
//   - D counts odd+odd index pairs summing to 2k+2.
// - Since c[2k] = E_k + D_{k-1}, D is carried into the next slot.
//
// NOTE(review): this is exact only while every count c[t] < 2^17 and every
// slot value stays below the modulus. It also assumes input values in [0, N],
// N + Q < 200010 (size of res) and (N + Q + 1) / 2 <= alloc_size. Confirm
// against the problem constraints.
int main() {
  using namespace ntt;
  int L = get_int(), M = get_int(), N = get_int();

  // s: bit width of one packed field; mask = 2^s is the weight of the odd index.
  const u32 s = 17, mask = 1 << s;
  auto one = m64(1), one17 = m64(mask);
  // f and g start at zero because they have static storage duration.
  // NOTE(review): += does not reduce; this is fine when each slot receives only a
  // few additions (at most two if input values are distinct).
  rep(i, L) {
    int n = get_int();
    f[n >> 1] += (n & 1) ? one17 : one;
  }
  rep(i, M) {
    // Store b reversed (index N - b) so the product index encodes a - b + N.
    int n = N - get_int();
    g[n >> 1] += (n & 1) ? one17 : one;
  }
  // Indices 0..N occupy slots 0..N/2.
  ntt::convolve(f, N / 2 + 1, g, N / 2 + 1);

  int Q = get_int();

  // Unpack slots back into per-index counts; only res[N .. N + Q) is printed.
  static int res[200010];
  u64 carry = 0;
  rep(i, 0, (N + Q + 1) / 2) {
    u64 n = f[i].get() + carry;
    res[2 * i + 0] = n & (mask - 1);         // c[2i] = E_i + D_{i-1}
    res[2 * i + 1] = (n >> s) & (mask - 1);  // c[2i+1] = O_i
    carry = n >> (2 * s);                    // D_i, belongs to index 2i + 2
  }
  rep(i, N, N + Q) put_u32(res[i]);
  return 0;
}
0