結果

問題 No.206 数の積集合を求めるクエリ
ユーザー Min_25
提出日時 2015-05-12 08:57:09
言語 C++11
(gcc 11.4.0)
結果
AC  
実行時間 45 ms / 7,000 ms
コード長 6,181 bytes
コンパイル時間 1,352 ms
コンパイル使用メモリ 85,120 KB
実行使用メモリ 6,332 KB
最終ジャッジ日時 2023-09-20 03:40:20
合計ジャッジ時間 4,277 ms
ジャッジサーバーID
(参考情報)
judge15 / judge12
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
4,380 KB
testcase_01 AC 1 ms
4,376 KB
testcase_02 AC 2 ms
4,380 KB
testcase_03 AC 2 ms
4,380 KB
testcase_04 AC 2 ms
4,380 KB
testcase_05 AC 1 ms
4,380 KB
testcase_06 AC 2 ms
4,380 KB
testcase_07 AC 3 ms
4,384 KB
testcase_08 AC 3 ms
4,384 KB
testcase_09 AC 2 ms
4,376 KB
testcase_10 AC 2 ms
4,380 KB
testcase_11 AC 2 ms
4,376 KB
testcase_12 AC 2 ms
4,380 KB
testcase_13 AC 3 ms
4,376 KB
testcase_14 AC 3 ms
4,376 KB
testcase_15 AC 2 ms
4,376 KB
testcase_16 AC 2 ms
4,380 KB
testcase_17 AC 43 ms
6,280 KB
testcase_18 AC 42 ms
6,184 KB
testcase_19 AC 43 ms
6,256 KB
testcase_20 AC 34 ms
6,168 KB
testcase_21 AC 43 ms
6,304 KB
testcase_22 AC 42 ms
6,140 KB
testcase_23 AC 43 ms
6,148 KB
testcase_24 AC 45 ms
6,252 KB
testcase_25 AC 45 ms
6,332 KB
testcase_26 AC 44 ms
6,164 KB
testcase_27 AC 43 ms
6,168 KB
testcase_28 AC 45 ms
6,160 KB
testcase_29 AC 45 ms
6,160 KB
testcase_30 AC 44 ms
6,140 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include <cassert>

#include <iostream>
#include <utility>
#include <algorithm>
#include <queue>
#include <functional>
#include <vector>
#include <map>
#include <set>

// Redirect every later use of getchar/putchar in this file to the POSIX
// *_unlocked variants, which skip the per-call stdio stream lock.
#define getchar getchar_unlocked
#define putchar putchar_unlocked

using namespace std;

// Short aliases for the built-in arithmetic types used throughout the file.
// The numeric suffix is the intended bit width.
// NOTE(review): the actual widths are platform-dependent; confirm on the target.
typedef long long int64;
typedef long long unsigned uint64;
typedef long double float80;
typedef unsigned short uint16;
typedef unsigned uint;
typedef unsigned char uint8;

// Reads the next unsigned decimal integer from standard input.
//
// Every leading character whose code is below '0' (spaces, newlines, ...) is
// skipped. Characters are then accumulated until one below '0' or the end of
// input is seen; that terminating character is consumed. Characters above '9'
// are not rejected, so the input is expected to be well-formed. Overflow wraps
// modulo 2^32 as usual for unsigned arithmetic.
//
// Returns the parsed value, or 0 if the input ends before any digit is found
// (previously this case spun forever because EOF compares below '0').
uint get_uint() {
  int c;
  do {
    c = getchar();
    if (c == EOF) {
      return 0;
    }
  } while (c < '0');

  uint n = c - '0';
  // EOF is negative, so it also terminates the accumulation loop.
  while ((c = getchar()) >= '0') {
    n = n * 10 + (c - '0');
  }
  return n;
}

// Writes `n` in decimal to standard output, followed by a single '\n'.
void put_uint(uint n) {
  const int kCapacity = 30;
  char digits[kCapacity];
  int first = kCapacity;

  // Digits are produced least-significant first, so the buffer is filled from
  // the back; afterwards digits[first..kCapacity) reads in printing order.
  // The loop body always runs once so that 0 is printed as "0".
  for (;;) {
    digits[--first] = static_cast<char>('0' + n % 10);
    n /= 10;
    if (n == 0) {
      break;
    }
  }

  for (int pos = first; pos < kCapacity; ++pos) {
    putchar(digits[pos]);
  }
  putchar('\n');
}

// Residue modulo `mod`, stored in Montgomery form with R = 2^64.
//
// Template parameters:
//   mod     - the modulus; must be odd and satisfy mod < 2^63 so that sums of
//             two residues fit in 64 bits and the sign-bit test in operator-
//             is valid.
//   z       - base used by root2pow(); root2pow(i) returns z^(2^(period - i)).
//             NOTE(review): presumably z has multiplicative order 2^period,
//             which makes root2pow(i) a primitive 2^i-th root of unity.
//   period  - see `z`.
//   n_prime - must satisfy mod * n_prime == -1 (mod 2^64); this is what makes
//             the low 64 bits cancel in montgomery_reduction().
//   r2      - multiplier used to enter Montgomery form; expected to be
//             2^128 mod `mod`.
template <uint64 mod, uint64 z, uint period, uint64 n_prime, uint64 r2>
class Mod64 {
private:
  typedef __uint128_t uint128;

public:
  // Constructs the zero residue (raw value 0 is the Montgomery form of 0).
  // Previously the member was left uninitialised.
  constexpr Mod64() : n_(0) {}

  // Converts an ordinary integer into Montgomery form. Deliberately not
  // `explicit`: callers rely on implicit conversion from integer literals.
  Mod64(uint64 v) : n_(montgomery_init(v)) {}

  // Returns z^(2^(period - i)). Requires i <= period.
  static Mod64 root2pow(uint i) {
    return Mod64(z).pow(1ull << (period - i));
  }

  Mod64 operator+ (Mod64 rhs) const {
    Mod64 ret;
    // Cannot overflow: both operands are below mod < 2^63.
    uint64 t = this->n_ + rhs.n_;
    ret.set_value(t >= mod ? t - mod : t);
    return ret;
  }
  Mod64 operator- (Mod64 rhs) const {
    Mod64 ret;
    // A wrapped (negative) difference has its top bit set because mod < 2^63.
    uint64 t = this->n_ - rhs.n_;
    ret.set_value(int64(t) < 0 ? t + mod : t);
    return ret;
  }
  Mod64 operator* (Mod64 rhs) const {
    Mod64 ret;
    uint64 t = montgomery_reduction(uint128(this->n_) * rhs.n_);
    ret.set_value(t);
    return ret;
  }

  // Compound assignments return a reference to *this, as is conventional.
  Mod64& operator+= (Mod64 rhs) {
    return *this = *this + rhs;
  }
  Mod64& operator-= (Mod64 rhs) {
    return *this = *this - rhs;
  }
  Mod64& operator*= (Mod64 rhs) {
    return *this = *this * rhs;
  }

  // Returns the stored Montgomery-form representation.
  uint64 get_raw_value() const {
    return this->n_;
  }
  // Returns the ordinary residue in [0, mod).
  uint64 get_value() const {
    return montgomery_reduction(this->n_);
  }
  // Overwrites the stored representation; `val` must already be in
  // Montgomery form and below mod.
  void set_value(uint64 val) {
    this->n_ = val;
  }

  // Returns (*this)^exp by binary exponentiation; exp == 0 yields 1.
  Mod64 pow(uint64 exp) const {
    Mod64 base = *this;
    Mod64 ret = Mod64(1);
    while (exp) {
      if (exp & 1) {
        ret *= base;
      }
      exp >>= 1;
      base *= base;
    }
    return ret;
  }
  // Returns (*this)^(mod - 2).
  // NOTE(review): this is the multiplicative inverse only when mod is prime
  // and *this is non-zero - confirm for any new modulus.
  Mod64 inverse() const {
    return pow(mod - 2);
  }

private:
  uint64 n_;  // value in Montgomery form, always below mod

  // Maps an ordinary integer w to its Montgomery form via reduce(w * r2).
  static uint64 montgomery_init(uint64 w) {
    return montgomery_reduction(uint128(w) * r2);
  }
  // Returns w / 2^64 modulo mod. x is chosen so that w + x * mod has zero low
  // 64 bits, which makes the shift an exact division.
  static uint64 montgomery_reduction(const uint128 w) {
    uint64 x = uint64(w) * n_prime;
    uint128 y = uint128(x) * mod + w;
    uint64 ret = y >> 64;
    if (ret >= mod) {
      ret -= mod;
    }
    return ret;
  }
};

// Number-theoretic transform and cyclic-convolution helper over `mod_t`.
//
// `mod_t` must provide: construction from an integer, copy, operators
// + - * += *=, `inverse()`, and a static `root2pow(i)`.
// NOTE(review): the transform is only correct if root2pow(i) is a primitive
// 2^i-th root of unity in the underlying field - confirm for new moduli.
template <typename mod_t>
class NTT {

public:
  // Multiplies the polynomials poly_a[0..size_a) and poly_b[0..size_b) and
  // stores the size_a + size_b - 1 product coefficients in poly_a.
  //
  // Both buffers must have room for 2 * P elements, where P is the smallest
  // power of two >= max(size_a, size_b); the tail of each buffer is zeroed and
  // poly_b is overwritten with its transform. Entries of poly_a beyond the
  // product length are left unscaled.
  //
  // If either input is empty the product is empty and nothing is scaled
  // (previously size_a + size_b - 1 wrapped around when both sizes were 0).
  static void convolute(mod_t* poly_a, uint size_a, mod_t* poly_b, uint size_b) {
    const uint size = std::max(size_a, size_b);
    uint ntt_size = 1;
    uint ldn = 0;
    while (ntt_size < size) {
      ntt_size <<= 1;
      ldn++;
    }
    // Double once more so the linear product does not wrap around.
    ntt_size <<= 1;
    ++ldn;

    // Build the zero element once instead of converting per element.
    const mod_t zero = mod_t(0);
    std::fill(poly_a + size_a, poly_a + ntt_size, zero);
    std::fill(poly_b + size_b, poly_b + ntt_size, zero);

    ntt_dif4(poly_a, ldn, 1);
    ntt_dif4(poly_b, ldn, 1);
    for (uint i = 0; i < ntt_size; ++i) {
      poly_a[i] *= poly_b[i];
    }
    ntt_dif4(poly_a, ldn, -1);

    const uint result_size =
        (size_a != 0 && size_b != 0) ? size_a + size_b - 1 : 0;
    // The inverse transform leaves every coefficient multiplied by ntt_size.
    const mod_t inv = mod_t(ntt_size).inverse();
    for (uint i = 0; i < result_size; ++i) {
      poly_a[i] *= inv;
    }
  }

  // In-place transform of f[0 .. 2^ldn). sign > 0 selects the forward
  // transform, sign < 0 the inverse one (without the 1/n scaling). The output
  // is in natural index order.
  static void ntt_dif4(mod_t* f, uint ldn, int sign) {
    ntt_dif4_core(f, ldn, sign);
    revbin_permute(f, 1u << ldn);
  }

private:
  // Replaces (a, b) by (a + b, a - b).
  static inline void sumdiff(mod_t& a, mod_t& b) {
    mod_t t = a - b;
    a += b;
    b = t;
  }

  // Permutes A[0..n) so that each element moves to the bit-reversed index.
  // n must be a power of two.
  static void revbin_permute(mod_t* A, uint n) {
    if (n <= 2) {
      return;
    }
    uint r = 0;
    const uint nh = n >> 1;
    for (uint x = 1; x < n; ++x) {
      // Increment r as a bit-reversed counter: flip bits from the top down
      // until a bit changes from 0 to 1.
      uint h = nh;
      while (!((r ^= h) & h)) {
        h >>= 1;
      }
      // Swap each pair only once.
      if (r > x) {
        std::swap(A[x], A[r]);
      }
    }
  }

  // Radix-4 decimation-in-frequency passes over f[0 .. 2^ldn), followed by one
  // radix-2 pass when ldn is odd. The result is in bit-reversed order.
  static void ntt_dif4_core(mod_t *f, uint ldn, int sign) {
    const uint LX = 2;
    const uint n = 1u << ldn;

    // Fourth root of unity; plays the role of the imaginary unit.
    mod_t imag = mod_t::root2pow(2);
    if (sign < 0) {
      imag = imag.inverse();
    }

    const mod_t one = mod_t(1);
    // ldm drops by 2 per pass and stops at 0 or 1, so the unsigned counter
    // never wraps.
    for (uint ldm = ldn; ldm >= LX; ldm -= LX) {
      const uint m = 1u << ldm;   // current block length
      const uint m4 = m >> LX;    // quarter of a block

      mod_t dw = mod_t::root2pow(ldm);
      if (sign < 0) {
        dw = dw.inverse();
      }
      // Twiddle factors dw^j, dw^(2j), dw^(3j) for the current j.
      mod_t w = one;
      mod_t w2 = w;
      mod_t w3 = w;

      for (uint j = 0; j < m4; ++j) {
        for (uint r = 0, i = j + r; r < n; r += m, i += m) {
          mod_t a0 = f[i + m4 * 0];
          mod_t a1 = f[i + m4 * 1];
          mod_t a2 = f[i + m4 * 2];
          mod_t a3 = f[i + m4 * 3];

          mod_t t02 = a0 + a2;
          mod_t t13 = a1 + a3;

          f[i + m4 * 0] = t02 + t13;
          f[i + m4 * 1] = (t02 - t13) * w2;

          t02 = a0 - a2;
          t13 = a1 - a3;
          t13 *= imag;

          f[i + m4 * 2] = (t02 + t13) * w;
          f[i + m4 * 3] = (t02 - t13) * w3;
        }

        w *= dw;
        w2 = w * w;
        w3 = w * w2;
      }
    }

    // Odd ldn: blocks of length 2 remain and need one plain butterfly each.
    if (ldn & 1) {
      for (uint i = 0; i < n; i += 2) {
        sumdiff(f[i], f[i+1]);
      }
    }
  }
};

// Concrete field used by the solver.
//   mod     = 0x3f91300000000001 = 0x3f913 * 2^44 + 1
//   z       = 0x1941B388165C78EB, period = 44
//             (NOTE(review): presumably an element of order 2^44 - confirm)
//   n_prime = 0x3f912fffffffffff, which satisfies mod * n_prime == -1 (mod 2^64)
//   r2      = 0x0298CD3E4612D42A
//             (NOTE(review): presumably 2^128 mod mod - confirm)
// The constants are hard-coded rather than computed at compile time.
typedef Mod64<0x3f91300000000001ull, 0x1941B388165C78EBull, 44, 
  0x3f912fffffffffffull, 0x0298CD3E4612D42Aull> mod64_t;

// Polynomial buffers handed to NTT::convolute by solve(); 2^17 entries each.
mod64_t A[1 << 17];
mod64_t B[1 << 17];

// Unpacked convolution coefficients, filled and printed by solve().
uint res[200011];

// solve() packs two neighbouring coefficients into one field element:
// the even-indexed one in the low BITS bits, the odd-indexed one scaled by
// 2^BITS. MASK extracts one BITS-wide slot.
const uint BITS = 17;
const uint MASK = (1 << BITS) - 1;

void solve() {
  uint L = get_uint();
  uint M = get_uint();
  uint N = get_uint();

  mod64_t one = mod64_t(1);
  mod64_t two17 = mod64_t(1 << BITS);
  for (uint i = 0; i < L; ++i) {
    uint n = get_uint();
    A[n / 2] += (n & 1 ? two17 : one);
  }
  for (uint i = 0; i < M; ++i) {
    uint n = N - get_uint();
    B[n / 2] += (n & 1 ? two17 : one);
  }

  NTT<mod64_t>::convolute(A, N / 2 + 1, B, N / 2 + 1);

  uint Q = get_uint();

  uint64 carry = 0;
  for (uint i = 0; i < N + 1; ++i) {
    uint64 n = A[i].get_value() + carry;
    res[2 * i + 0] = n & MASK;
    res[2 * i + 1] = (n >> BITS) & MASK;
    carry = n >> (BITS * 2);
  }

  for (uint i = N; i < N + Q; ++i) {
    put_uint(res[i]);
  }
}

// Program entry point: handles the single test case and exits with status 0
// (falling off the end of main is equivalent to `return 0;`).
int main() {
  solve();
}
0