結果

問題 No.214 素数サイコロと合成数サイコロ (3-Medium)
ユーザー Min_25Min_25
提出日時 2015-05-23 15:44:17
言語 C++11 (gcc 11.4.0)
結果
AC  
実行時間 213 ms / 3,000 ms
コード長 12,206 bytes
コンパイル時間 1,445 ms
コンパイル使用メモリ 109,812 KB
実行使用メモリ 4,380 KB
最終ジャッジ日時 2023-09-20 10:53:08
合計ジャッジ時間 2,739 ms
ジャッジサーバーID (参考情報) judge15 / judge14
このコードへのチャレンジ(β)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 209 ms 4,380 KB
testcase_01 AC 213 ms 4,380 KB
testcase_02 AC 213 ms 4,376 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include <cassert>

#include <iostream>
#include <utility>
#include <algorithm>
#include <queue>
#include <functional>
#include <vector>
#include <map>
#include <set>
#include <complex>

// Route character I/O through the lock-free stdio variants. No threads are
// used in this program, so skipping the stream lock is safe.
#define getchar getchar_unlocked
#define putchar putchar_unlocked

using namespace std;

// Short integer/float aliases used throughout the file.
using int64 = long long;
using uint64 = long long unsigned;
using float80 = long double;
using uint16 = unsigned short;
using uint = unsigned;
using uint8 = unsigned char;

// Reads the next unsigned decimal integer from stdin.
// Skips every character that is not a decimal digit, then accumulates digits
// until the first non-digit (which is consumed) or EOF. Returns 0 if EOF is
// reached before any digit is seen. Values above UINT_MAX wrap silently.
uint get_uint() {
  int c = getchar();
  // Skip separators; stop on EOF instead of spinning forever on it.
  while (c != EOF && (c < '0' || c > '9')) {
    c = getchar();
  }
  if (c == EOF) {
    return 0;
  }
  uint n = 0;
  // Only '0'..'9' are digits; characters above '9' (':', letters, ...)
  // must terminate the number rather than be folded into it.
  while (c >= '0' && c <= '9') {
    n = n * 10 + uint(c - '0');
    c = getchar();
  }
  return n;
}

// Writes n in decimal to stdout, followed by a newline.
void put_uint(uint n) {
  uint8 digits[30];
  int len = 0;
  // Collect digits least-significant first; n == 0 still yields one digit.
  for (;;) {
    digits[len++] = uint8('0' + n % 10);
    n /= 10;
    if (n == 0) {
      break;
    }
  }
  // Emit them most-significant first.
  while (len != 0) {
    --len;
    putchar(digits[len]);
  }
  putchar('\n');
}


// mod < 2^63
// Arithmetic modulo a fixed 64-bit odd modulus in Montgomery representation
// with R = 2^64: the stored word n_ equals value * R mod `mod`.
// Requires mod < 2^63: operator- detects a borrow through the sign bit, and
// montgomery_reduction's 128-bit intermediate must not overflow.
//
// Template parameters (their values come from the instantiating typedefs and
// are not checked here):
//   mod     - the modulus; inverse() uses Fermat's little theorem, so it
//             assumes `mod` is prime.
//   z       - base used by root2pow(); presumably a primitive 2^period-th
//             root of unity modulo `mod` (TODO confirm per instantiation).
//   period  - exponent used by root2pow(); presumably log2 of z's order.
//   n_prime - must equal -mod^{-1} mod 2^64 so that the low 64 bits cancel
//             in montgomery_reduction.
//   r2      - must equal R^2 mod `mod` so that montgomery_init maps v to
//             v * R mod `mod`.
template <uint64 mod, uint64 z, uint period, uint64 n_prime, uint64 r2>
class Mod64 {
private:
  typedef __uint128_t uint128;

public:
  // Does not touch n_ (objects with static storage start zeroed, i.e. 0).
  Mod64() {}
  // Implicit on purpose: plain integers (e.g. the 0 passed to std::fill)
  // convert to Mod64 through this constructor.
  Mod64(uint64 v) {
    n_ = montgomery_init(v);
  }
  // Returns z^(2^(period - i)); requires i <= period.
  static Mod64 root2pow(uint i) {
    return Mod64(z).pow(1ull << (period - i));
  }
  Mod64 operator+ (Mod64 rhs) const {
    Mod64 ret;
    uint64 t = this->n_ + rhs.n_;
    ret.set_value(t >= mod ? t - mod : t);
    return ret;
  }
  Mod64 operator- (Mod64 rhs) const {
    Mod64 ret;
    uint64 t = this->n_ - rhs.n_;
    // Both operands are < mod < 2^63, so a borrow shows up as a negative int64.
    ret.set_value(int64(t) < 0 ? t + mod : t);
    return ret;
  }
  Mod64 operator* (Mod64 rhs) const {
    Mod64 ret;
    uint64 t = montgomery_reduction(uint128(this->n_) * rhs.n_);
    ret.set_value(t);
    return ret;
  }
  // Compound assignments return *this by reference, like built-in types, so
  // chained forms such as (a += b) += c update `a` rather than a temporary.
  Mod64& operator+= (Mod64 rhs) {
    return *this = *this + rhs;
  }
  Mod64& operator-= (Mod64 rhs) {
    return *this = *this - rhs;
  }
  Mod64& operator*= (Mod64 rhs) {
    return *this = *this * rhs;
  }
  // Raw Montgomery word (value * R mod `mod`).
  uint64 get_raw_value() const {
    return this->n_;
  }
  // Ordinary residue in [0, mod).
  uint64 get_value() const {
    return montgomery_reduction(this->n_);
  }
  // Stores a raw Montgomery word (no conversion).
  void set_value(uint64 val) {
    this->n_ = val;
  }
  // Binary exponentiation: *this raised to `exp`.
  Mod64 pow(uint64 exp) const {
    Mod64 base = *this;
    Mod64 ret = Mod64(1);
    while (exp) {
      if (exp & 1) {
        ret *= base;
      }
      exp >>= 1;
      base *= base;
    }
    return ret;
  }
  // Multiplicative inverse via x^(mod-2); valid only for prime `mod` and x != 0.
  Mod64 inverse() const {
    return pow(mod - 2);
  }

private:
  uint64 n_;

  // Converts an ordinary value into Montgomery form: w * R^2 * R^{-1} = w * R.
  static uint64 montgomery_init(uint64 w) {
    return montgomery_reduction(uint128(w) * r2);
  }
  // Returns w * R^{-1} mod `mod` for w < mod * 2^64. Adding x * mod, with
  // x = w * n_prime mod 2^64, clears the low 64 bits of the sum.
  static uint64 montgomery_reduction(const uint128 w) {
    uint64 x = uint64(w) * n_prime;
    uint128 y = uint128(x) * mod + w;
    uint64 ret = y >> 64;
    if (ret >= mod) {
      ret -= mod;
    }
    return ret;
  }
};

// Number-theoretic transform and convolution over the modular type mod_t.
// mod_t must provide +, -, *, *=, +=, inverse(), construction from an
// integer, and root2pow(k); root2pow(k) is assumed to be a primitive 2^k-th
// root of unity (TODO confirm for each mod_t used).
// All routines work in place on caller-provided arrays.
template <typename mod_t>
class NTT {

public:
  // Squares the polynomial poly_a[0 .. size_a) in place.
  // poly_a must have room for ntt_size = 2 * (smallest power of two >= size_a)
  // entries; entries from size_a on are zero-filled here. On return the first
  // 2 * size_a - 1 entries hold the square (scaled by 1/ntt_size).
  // Requires size_a >= 1 (otherwise 2 * size_a - 1 underflows).
  static void auto_convolute(mod_t* poly_a, uint size_a) {
    uint size = size_a;

    // ntt_size = 2^ldn: next power of two >= size, then doubled so the
    // cyclic convolution does not wrap.
    uint ntt_size = 1;
    uint ldn = 0;
    while (ntt_size < size) {
      ntt_size <<= 1;
      ldn++;
    }
    ntt_size <<= 1;
    ++ldn;

    fill(poly_a + size_a, poly_a + ntt_size, 0);

    // Forward transform, pointwise square, inverse transform.
    ntt_dit4(poly_a, ldn, 1);
    for (uint i = 0; i < ntt_size; ++i) {
      poly_a[i] *= poly_a[i];
    }
    ntt_dit4(poly_a, ldn, -1);

    // The inverse transform is unscaled; divide by ntt_size.
    mod_t inv = mod_t(ntt_size).inverse();
    for (uint i = 0; i < 2 * size_a - 1; ++i) {
      poly_a[i] *= inv;
    }
  }

  // Multiplies poly_a[0 .. size_a) by poly_b[0 .. size_b), leaving the
  // product's first size_a + size_b - 1 coefficients in poly_a.
  // Both arrays need room for ntt_size = 2 * (smallest power of two >=
  // max(size_a, size_b)) entries; both are zero-padded and poly_b is
  // overwritten with its transform.
  static void convolute(mod_t* poly_a, uint size_a, mod_t* poly_b, uint size_b) {
    uint size = max(size_a, size_b);

    uint ntt_size = 1;
    uint ldn = 0;
    while (ntt_size < size) {
      ntt_size <<= 1;
      ldn++;
    }
    ntt_size <<= 1;
    ++ldn;

    fill(poly_a + size_a, poly_a + ntt_size, 0);
    fill(poly_b + size_b, poly_b + ntt_size, 0);

    ntt_dit4(poly_a, ldn, 1);
    ntt_dit4(poly_b, ldn, 1);
    for (uint i = 0; i < ntt_size; ++i) {
      poly_a[i] *= poly_b[i];
    }
    ntt_dit4(poly_a, ldn, -1);

    // Scale by 1/ntt_size only where the product has coefficients.
    mod_t inv = mod_t(ntt_size).inverse();
    for (uint i = 0; i < size_a + size_b - 1; ++i) {
      poly_a[i] *= inv;
    }
  }

  // In-place transform of length 2^ldn: bit-reversal permutation followed by
  // radix-4 decimation-in-time passes. sign >= 0 uses the roots from
  // root2pow, sign < 0 their inverses. The output is NOT scaled by 1/n.
  static void ntt_dit4(mod_t* f, uint ldn, int sign) {
    revbin_permute(f, 1u << ldn);
    ntt_dit4_core(f, ldn, sign);
  }

private:
  // (a, b) <- (a + b, a - b).
  static inline void sumdiff(mod_t& a, mod_t& b) {
    mod_t t = a - b;
    a += b;
    b = t;
  }

  // In-place bit-reversal permutation of A[0 .. n), n a power of two.
  // r is advanced as a "reversed counter": adding one at the top bit and
  // propagating the carry downward, so r == bitreverse(x) at each step.
  static void revbin_permute(mod_t* A, uint n) {
    if (n <= 2) {
      return;
    }
    uint r = 0;
    uint nh = n >> 1;
    for (uint x = 1; x < n; ++x) {
      uint h = nh;
      while (! ((r ^= h) & h)) {
        h >>= 1;
      }
      // Swap each pair once.
      if (r > x) {
        swap(A[x], A[r]);
      }
    }
  }

  // Butterfly passes on already bit-reversed data of length n = 2^ldn.
  static void ntt_dit4_core(mod_t *f, uint ldn, int sign) {
    const uint LX = 2;
    const uint n = 1u << ldn;
    
    // Odd ldn: one radix-2 pass first so the remaining passes are radix-4.
    if (ldn & 1) {
      for (uint i = 0; i < n; i += 2) {
        sumdiff(f[i], f[i+1]);
      }
    }

    // 4th root of unity (plays the role of the imaginary unit), inverted
    // for the inverse transform.
    mod_t imag = mod_t::root2pow(2);
    if (sign < 0) {
      imag = imag.inverse();
    }

    uint ldm = LX + (ldn & 1);

    mod_t one = mod_t(1);
    // Each pass combines four sub-transforms of length m4 into one of length m.
    for (; ldm <= ldn; ldm += LX) {
      const uint m = 1u << ldm;
      const uint m4 = m >> LX;

      // dw: principal root for this pass (m-th root of unity).
      mod_t dw = mod_t::root2pow(ldm);
      if (sign < 0) {
        dw = dw.inverse();
      }
      mod_t w = one;
      mod_t w2 = w;
      mod_t w3 = w;

      for (uint j = 0; j < m4; ++j) {
        for (uint r = 0, i = j + r; r < n; r += m, i += m) {
          // Bit-reversed layout: quarter 1 pairs with w^2, quarter 2 with w.
          mod_t a0 = f[i + m4 * 0];
          mod_t a2 = f[i + m4 * 1] * w2;
          mod_t a1 = f[i + m4 * 2] * w;
          mod_t a3 = f[i + m4 * 3] * w3;

          mod_t t02 = a0 + a2;
          mod_t t13 = a1 + a3;

          f[i + m4 * 0] = t02 + t13;
          f[i + m4 * 2] = t02 - t13;

          t02 = a0 - a2;
          t13 = a1 - a3;
          t13 *= imag;

          f[i + m4 * 1] = t02 + t13;
          f[i + m4 * 3] = t02 - t13;
        }

        // Advance twiddles: w = dw^(j+1), w2 = w^2, w3 = w^3.
        w *= dw;
        w2 = w * w;
        w3 = w * w2;
      }
    }
  }
};

// Two NTT primes of the form c * 2^44 + 1 (c = 0x3f913 and 0x3f933, so
// P1 < P2); presumed prime. Products are computed modulo each and combined
// with the CRT in _poly_mul_restore.
constexpr uint64 P1 = 0x3f91300000000001ull;
constexpr uint64 P2 = 0x3f93300000000001ull;
// Used by the CRT step as P1^{-1} mod P2.
constexpr uint64 P1_INV_MOD_P2 = 0x19D3CB8000001FCAull;

// Mod64<mod, z, period, n_prime, r2> for each prime. n_prime is
// -P^{-1} mod 2^64 (= c * 2^44 - 1 for P = c * 2^44 + 1); z and r2 are
// presumably a primitive 2^44-th root of unity and 2^128 mod P (TODO confirm).
using mod64_t1 = Mod64<P1, 0x1941B388165C78EBull, 44,
                       0x3f912fffffffffffull, 0x0298CD3E4612D42Aull>;

using mod64_t2 = Mod64<P2, 0x394ba9c52fec1825ull, 44,
                       0x3f932fffffffffffull, 0x30da2c2b74e27a1full>;

// ------------------------------------------------------------------------------

// Modulus of the final answer (10^9 + 7).
constexpr uint mod = 1000000007u;
// Largest multiple of `mod` with lim + mod * mod <= 2^64 - 1; _poly_rem
// subtracts it from its 64-bit accumulators to keep them from overflowing.
constexpr uint64 lim = (0xFFFFFFFFFFFFFFFFull / mod - mod) * mod;
// Largest supported P and C; the buffers below are sized from it.
constexpr uint MAX_PC = 300;
// Entries per NTT work buffer: room for a length-2^14 transform, plus one.
constexpr uint NTT_BUFF_SIZE = (1 << 14) + 1;

// DP table filled by init_poly: dp[t][s] for t dice with face sum s.
uint dp[MAX_PC + 1][MAX_PC * 13 + 1];
// Sum distributions for the prime dice (faces up to 13), the composite dice
// (faces up to 12), and their product.
uint poly_P[MAX_PC * 13 + 1];
uint poly_C[MAX_PC * 12 + 1];
uint poly[MAX_PC * 25 + 1];
// Power-series inverse of the divisor, filled by poly_power_rem.
uint poly_inv[MAX_PC * 25 + 100];
// Not referenced elsewhere in this file.
uint sums[MAX_PC * 25 + 1];
uint counts[MAX_PC * 25 + 1];

// Output of poly_power_rem, and the 64-bit accumulator used by _poly_rem.
uint rem[2 * MAX_PC * 25 + 100];
uint64 temp[2 * MAX_PC * 25 + 100];

// Per-prime NTT operands: A* hold the first factor (and the result),
// B* the second factor.
mod64_t1 A1[NTT_BUFF_SIZE], B1[NTT_BUFF_SIZE];
mod64_t2 A2[NTT_BUFF_SIZE], B2[NTT_BUFF_SIZE];


void _poly_mul_restore(const uint size, mod64_t1* a1, mod64_t2* a2, uint* res) {
  for(uint i = 0; i < size; ++i) {
    uint64 x1 = a1[i].get_value();
    uint64 x2 = a2[i].get_value();
    if(x1 != x2) {
      uint64 dx = x2 - x1;
      if(int64(dx) < 0) {
        dx += P2; // P1 < P2
      }
      dx = (mod64_t2(dx) * mod64_t2(P1_INV_MOD_P2)).get_value();
      x1 = (x1 + dx % mod * (P1 % mod));
    }
    res[i] = x1 % mod;
  }
}

// Squares p1 (s1 >= 1 coefficients) modulo `mod` via two NTTs plus the CRT.
// Writes 2 * s1 - 1 coefficients to res and returns that count. p1 is fully
// loaded into the global NTT buffers first, so res may alias p1.
uint poly_square(const uint* p1, uint s1, uint* res) {
  const uint* const end = p1 + s1;
  mod64_t1* dst1 = A1;
  mod64_t2* dst2 = A2;
  for (const uint* src = p1; src != end; ++src) {
    *dst1++ = mod64_t1(*src);
    *dst2++ = mod64_t2(*src);
  }
  NTT<mod64_t1>::auto_convolute(A1, s1);
  NTT<mod64_t2>::auto_convolute(A2, s1);
  const uint out_len = 2 * s1 - 1;
  _poly_mul_restore(out_len, A1, A2, res);
  return out_len;
}

// Multiplies p1 (s1 coefficients) by p2 (s2 coefficients) modulo `mod`.
// Writes s1 + s2 - 1 coefficients to res and returns that count. Both inputs
// are loaded into the global NTT buffers before res is written, so res may
// alias either input.
uint poly_mul(const uint* p1, uint s1, const uint* p2, uint s2, uint* res) {
  for (uint k = 0; k != s1; ++k) {
    const uint coef = p1[k];
    A1[k] = mod64_t1(coef);
    A2[k] = mod64_t2(coef);
  }
  for (uint k = 0; k != s2; ++k) {
    const uint coef = p2[k];
    B1[k] = mod64_t1(coef);
    B2[k] = mod64_t2(coef);
  }
  // Product modulo each NTT prime, then recombine to `mod`.
  NTT<mod64_t1>::convolute(A1, s1, B1, s2);
  NTT<mod64_t2>::convolute(A2, s1, B2, s2);
  const uint out_len = s1 + s2 - 1;
  _poly_mul_restore(out_len, A1, A2, res);
  return out_len;
}

// Schoolbook polynomial remainder, coefficients stored highest-degree first
// (index 0 is the leading coefficient).
//
// Reduces p1 (s1 coefficients) by p2 (s2 coefficients). p2's leading
// coefficient is assumed to be 1: each step cancels temp[i] by adding
// -temp[i] * p2[j] for j in [ofs, s2), without dividing by p2[0]. Passing
// ofs = 1 skips updating temp[i] itself, which is never read again.
// Accumulation happens in the global 64-bit `temp`; subtracting `lim` (a
// multiple of `mod`) keeps it from overflowing.
//
// Writes s2 - 1 remainder coefficients to ret and returns s2 - 1. When
// s1 < s2, p1 is copied to ret unchanged and s1 is returned. ret may alias p1.
uint _poly_rem(const uint* p1, uint s1, const uint* p2, uint s2, uint ofs, uint* ret) {
  if (s1 < s2) {
    // memmove (unlike std::copy) is well defined when ret overlaps p1.
    std::memmove(ret, p1, s1 * sizeof(uint));
    return s1;
  }
  copy(p1, p1 + s1, temp);
  uint dif = s1 - s2 + 1;
  for (uint i = 0; i < dif; ++i) {
    uint64 coef = temp[i] % mod;
    if (!coef) {
      continue;
    }
    coef = mod - coef;
    for (uint j = ofs; j < s2; ++j) {
      temp[i + j] += coef * p2[j];
      if (temp[i + j] >= lim) {
        temp[i + j] -= lim;
      }
    }
  }
  // The low s2 - 1 entries of temp are the remainder.
  for (uint i = dif; i < s1; ++i) {
    ret[i - dif] = temp[i] % mod;
  }
  return s2 - 1;
}

// Global scratch buffers: poly_tmp / poly_tmp2 are used by poly_inverse, and
// poly_q by poly_rem (it receives poly_div's output and the product q * poly2).
// Because they are globals, the functions using them are not reentrant.
uint poly_tmp[2 * MAX_PC * 25 + 100];
uint poly_tmp2[2 * MAX_PC * 25 + 100];
uint poly_q[2 * MAX_PC * 25 + 100];

// Quotient of poly1 (s1 coefficients, highest-degree first) by a divisor of
// s2 coefficients, using the divisor's precomputed power-series inverse `inv`.
// In this storage order the quotient is the first s1 - s2 + 1 coefficients of
// poly1 * inv. res must have room for the whole product
// (2 * (s1 - s2 + 1) - 1 entries); only the returned count is meaningful.
// The divisor itself is not needed, hence the unnamed parameter.
uint poly_div(const uint* poly1, uint s1, const uint* /*divisor*/, uint s2,
    const uint *inv, uint inv_size, uint *res) {
  assert(s1 >= s2);
  const uint q_len = s1 - s2 + 1;
  assert(inv_size >= q_len);
  poly_mul(poly1, q_len, inv, q_len, res);
  return q_len;
}

// Remainder of poly1 by poly2 (both stored highest-degree first), computed as
// poly1 - q * poly2, with q from poly_div and the precomputed power-series
// inverse `inv` (inv_size coefficients) of poly2.
// Writes s2 - 1 coefficients to res and returns s2 - 1. If poly1 is already
// shorter than poly2 it is copied unchanged and s1 is returned.
// res may alias poly1. Uses the global scratch buffer poly_q.
uint poly_rem(const uint* poly1, uint s1, const uint* poly2, uint s2, uint* inv, uint inv_size, uint* res) {
  if (s1 < s2) {
    // poly_power_rem passes res == poly1; std::copy is undefined when the
    // destination starts inside the source range, memmove is not.
    std::memmove(res, poly1, s1 * sizeof(uint));
    return s1;
  }
  uint q_size = poly_div(poly1, s1, poly2, s2, inv, inv_size, poly_q);
  poly_mul(poly_q, q_size, poly2, s2, poly_q);
  // The top q_size coefficients cancel; keep the low s2 - 1.
  // Reading poly1[ofs + i] with ofs >= 1 stays ahead of the write to res[i],
  // so in-place use is safe.
  uint ofs = s1 - s2 + 1;
  for (uint i = 0; i < s2 - 1; ++i) {
    res[i] = (poly1[ofs + i] + mod - poly_q[ofs + i]) % mod;
  }
  return s2 - 1;
}

// Newton iteration for the power-series inverse of `poly` modulo
// x^needed_size. Assumes poly[0] == 1 (the iteration starts from inv = 1).
// Coefficients at index >= size are treated as zero. For needed_size >= 1,
// writes needed_size coefficients to inv and returns that count.
// Uses the global scratch buffers poly_tmp and poly_tmp2.
uint poly_inverse(const uint *poly, uint size, uint* inv, uint needed_size) {
  // Precision schedule: needed_size - 1, repeatedly halved down to 1. It is
  // consumed from the back, so the precision roughly doubles each round.
  uint schedule[100];
  uint steps = 0;
  for (uint d = needed_size - 1; d != 0; d >>= 1) {
    schedule[steps++] = d;
  }

  // Working copy of the first needed_size coefficients, zero-padded.
  const uint have = size < needed_size ? size : needed_size;
  copy(poly, poly + have, poly_tmp);
  fill(poly_tmp + have, poly_tmp + needed_size, 0);

  inv[0] = 1;
  uint known = 1;  // number of inverse coefficients computed so far
  while (steps != 0) {
    const uint target = schedule[--steps];  // highest degree wanted this round
    const uint fresh = target + 1 - known;  // coefficients added this round
    // Error term: degrees [known, target] of poly * inv.
    poly_mul(poly_tmp, target + 1, inv, known, poly_tmp2);
    // Correction: error * inv, truncated to `fresh` terms, then negated.
    poly_mul(poly_tmp2 + known, fresh, inv, fresh, poly_tmp2);
    for (uint k = 0; k < fresh; ++k) {
      inv[known + k] = (mod - poly_tmp2[k] % mod) % mod;
    }
    known += fresh;
  }
  return known;
}

// Computes x^e modulo `divisor` (stored highest-degree first, as in poly_rem),
// writes the remainder's coefficients to ret, and returns how many were
// written. Uses left-to-right binary exponentiation: for each bit of e from
// the top, multiply by x when the bit is set (append a zero coefficient), then
// square and reduce. The divisor's power-series inverse is computed once into
// the global poly_inv.
uint poly_power_rem(uint64 e, const uint* divisor, uint div_size, uint* ret) {
  // Highest set bit of e (0 when e == 0). Scanning downward from bit 63
  // cannot overflow, unlike an upward shift loop that never terminates once
  // e has bit 63 set.
  uint64 mask = 0;
  if (e != 0) {
    mask = uint64(1) << 63;
    while ((e & mask) == 0) {
      mask >>= 1;
    }
  }

  uint ret_size = 0;
  ret[ret_size++] = 1;  // start from x^0

  const uint inv_size = poly_inverse(divisor, div_size, poly_inv, div_size);

  while (mask) {
    if (e & mask) {
      ret[ret_size++] = 0;  // multiply by x
    }
    mask >>= 1;
    if (mask == 0) {
      break;
    }
    ret_size = poly_square(ret, ret_size, ret);
    ret_size = poly_rem(ret, ret_size, divisor, div_size, poly_inv, inv_size, ret);
  }
  // The final multiply-by-x may leave one coefficient too many.
  if (ret_size >= div_size) {
    ret_size = poly_rem(ret, ret_size, divisor, div_size, poly_inv, inv_size, ret);
  }
  return ret_size;
}


// Rewrites poly[0 .. size) in place as 1 - (poly[1] x + poly[2] x^2 + ...):
// the constant term becomes 1 and every other coefficient is negated modulo
// `mod` (zeros stay zero). Intended for coefficients already in [0, mod).
void conv(uint* poly, uint size) {
  poly[0] = 1;
  for (uint k = 1; k < size; ++k) {
    const uint c = poly[k];
    poly[k] = (c == 0) ? 0 : mod - c;
  }
}

// a <- (a + b) mod `mod`, for a and b already in [0, mod).
inline void add_mod(uint& a, uint b) {
  const uint sum = a + b;
  a = (sum >= mod) ? sum - mod : sum;
}

void init_poly(const uint* dice, const uint T, uint* result) {
  for (uint i = 1; i <= T; ++i) {
    fill(dp[i], dp[i] + dice[5] * i + 1, 0);
  }
  dp[0][0] = 1;
  for (uint di = 0; di < 6; ++di) {
    const uint d = dice[di];
    for (uint t = 0; t < T; ++t) {
      for (uint i = t * dice[0]; i <= t * dice[di]; ++i) {
        if (dp[t][i]) {
          add_mod(dp[t + 1][i + d], dp[t][i]);
        }
      }
    }
  }
  for (uint i = 0; i <= dice[5] * T; ++i) {
    result[i] = dp[T][i];
  }
}

void solve() {
  const uint Ps[] = {2, 3, 5, 7, 11, 13};
  const uint Cs[] = {4, 6, 8, 9, 10, 12};
  
  uint64 N;
  uint P, C;

  while (~scanf("%llu %u %u", &N, &P, &C)) {
    init_poly(Ps, P, poly_P);
    init_poly(Cs, C, poly_C);
    uint size_p = 13 * P + 1;
    uint size_c = 12 * C + 1;

    poly_mul(poly_P, size_p, poly_C, size_c, poly);
    const uint poly_size = 13 * P + 12 * C + 1;
    conv(poly, poly_size);
    
    uint rem_size = poly_power_rem(N + poly_size - 2, poly, poly_size, rem);
    uint64 ans = 0;
    for (uint i = 0; i < rem_size; ++i) {
      ans += rem[i];
    }
    printf("%llu\n", ans % mod);
  }
}

// Program entry point: all work happens in solve(). Falling off the end of
// main returns 0.
int main() {
  solve();
}
0