結果
問題 | No.214 素数サイコロと合成数サイコロ (3-Medium) |
ユーザー | Min_25 |
提出日時 | 2015-05-23 15:44:17 |
言語 | C++11 (gcc 11.4.0) |
結果 |
AC
|
実行時間 | 199 ms / 3,000 ms |
コード長 | 12,206 bytes |
コンパイル時間 | 1,106 ms |
コンパイル使用メモリ | 95,180 KB |
実行使用メモリ | 6,944 KB |
最終ジャッジ日時 | 2024-07-06 06:20:38 |
合計ジャッジ時間 | 2,009 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge5 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 175 ms
6,816 KB |
testcase_01 | AC | 179 ms
6,944 KB |
testcase_02 | AC | 199 ms
6,940 KB |
ソースコード
#include <cstdio>
#include <cmath>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include <cassert>
#include <iostream>
#include <utility>
#include <algorithm>
#include <queue>
#include <functional>
#include <vector>
#include <map>
#include <set>
#include <complex>

#define getchar getchar_unlocked
#define putchar putchar_unlocked

using namespace std;

typedef long long int64;
typedef long long unsigned uint64;
typedef long double float80;
typedef unsigned short uint16;
typedef unsigned uint;
typedef unsigned char uint8;

// Reads one unsigned decimal number from stdin.
// Every character below '0' is skipped first; the number ends at the first
// character below '0' that follows the digits. There is no EOF handling:
// EOF is negative, so the skip loop would keep reading.
uint get_uint() {
  int ch = getchar();
  while (ch < '0') {
    ch = getchar();
  }
  uint value = ch - '0';
  for (ch = getchar(); ch >= '0'; ch = getchar()) {
    value = value * 10 + (ch - '0');
  }
  return value;
}

// Writes n in decimal followed by a newline.
void put_uint(uint n) {
  uint8 digits[30];
  int len = 0;
  // Digits are produced least-significant first and emitted in reverse.
  do {
    digits[len] = n % 10 + '0';
    ++len;
    n /= 10;
  } while (n != 0);
  for (int k = len - 1; k >= 0; --k) {
    putchar(digits[k]);
  }
  putchar('\n');
}

// Residue modulo `mod`, stored in Montgomery form.
//
// mod < 2^63 (operator- relies on the top bit to detect a borrow).
//
// Template parameters:
//   mod     - the modulus.
//   z       - base of root2pow(): root2pow(i) = z^(2^(period - i)).
//   period  - see z.
//   n_prime - multiplier used by the reduction step
//             (presumably -mod^{-1} mod 2^64; confirm when adding a modulus).
//   r2      - factor used to convert a plain value into Montgomery form
//             (presumably 2^128 mod `mod`; confirm when adding a modulus).
template <uint64 mod, uint64 z, uint period, uint64 n_prime, uint64 r2>
class Mod64 {
private:
  typedef __uint128_t uint128;

public:
  // Leaves the stored value untouched (no initialisation).
  Mod64() {}

  // Converts the plain value v into Montgomery form.
  Mod64(uint64 v) { n_ = montgomery_init(v); }

  static Mod64 root2pow(uint i) {
    return Mod64(z).pow(1ull << (period - i));
  }

  Mod64 operator+ (Mod64 rhs) const {
    uint64 sum = n_ + rhs.n_;
    if (sum >= mod) {
      sum -= mod;
    }
    Mod64 out;
    out.set_value(sum);
    return out;
  }

  Mod64 operator- (Mod64 rhs) const {
    uint64 diff = n_ - rhs.n_;
    // A wrapped (borrowed) difference has its top bit set because mod < 2^63.
    if (int64(diff) < 0) {
      diff += mod;
    }
    Mod64 out;
    out.set_value(diff);
    return out;
  }

  Mod64 operator* (Mod64 rhs) const {
    Mod64 out;
    out.set_value(montgomery_reduction(uint128(n_) * rhs.n_));
    return out;
  }

  Mod64 operator+= (Mod64 rhs) { return *this = *this + rhs; }
  Mod64 operator-= (Mod64 rhs) { return *this = *this - rhs; }
  Mod64 operator*= (Mod64 rhs) { return *this = *this * rhs; }

  // Stored (Montgomery-form) representation.
  uint64 get_raw_value() const { return n_; }

  // Plain value: one reduction of the stored representation.
  uint64 get_value() const { return montgomery_reduction(n_); }

  // Overwrites the stored representation directly (no conversion).
  void set_value(uint64 val) { n_ = val; }

  // Binary exponentiation.
  Mod64 pow(uint64 exp) const {
    Mod64 acc = Mod64(1);
    for (Mod64 sq = *this; exp != 0; exp >>= 1, sq *= sq) {
      if (exp & 1) {
        acc *= sq;
      }
    }
    return acc;
  }

  // Computed as this^(mod - 2).
  Mod64 inverse() const { return pow(mod - 2); }

private:
  uint64 n_;

  static uint64 montgomery_init(uint64 w) {
    return montgomery_reduction(uint128(w) * r2);
  }

  static uint64 montgomery_reduction(const uint128 w) {
    const uint64 q = uint64(w) * n_prime;
    const uint128 total = uint128(q) * mod + w;
    const uint64 high = uint64(total >> 64);
    return high >= mod ? high - mod : high;
  }
};

// Radix-4 decimation-in-time number-theoretic transform over mod_t, plus
// in-place convolution helpers built on top of it.
template <typename mod_t>
class NTT {
public:
  // Squares the polynomial poly_a (size_a coefficients) in place.
  // The buffer must have room for the transform length: twice the smallest
  // power of two that is >= size_a.
  static void auto_convolute(mod_t* poly_a, uint size_a) {
    uint ldn = 0;
    const uint ntt_size = transform_length(size_a, ldn);

    fill(poly_a + size_a, poly_a + ntt_size, 0);
    ntt_dit4(poly_a, ldn, 1);
    for (uint k = 0; k < ntt_size; ++k) {
      poly_a[k] *= poly_a[k];
    }
    ntt_dit4(poly_a, ldn, -1);

    // Only the meaningful coefficients are scaled by 1 / ntt_size.
    const mod_t scale = mod_t(ntt_size).inverse();
    for (uint k = 0; k < 2 * size_a - 1; ++k) {
      poly_a[k] *= scale;
    }
  }

  // Multiplies poly_a by poly_b; the product is left in poly_a and poly_b is
  // overwritten with its transform. Both buffers must have room for the
  // transform length: twice the smallest power of two >= max(size_a, size_b).
  static void convolute(mod_t* poly_a, uint size_a, mod_t* poly_b, uint size_b) {
    uint ldn = 0;
    const uint ntt_size = transform_length(max(size_a, size_b), ldn);

    fill(poly_a + size_a, poly_a + ntt_size, 0);
    fill(poly_b + size_b, poly_b + ntt_size, 0);
    ntt_dit4(poly_a, ldn, 1);
    ntt_dit4(poly_b, ldn, 1);
    for (uint k = 0; k < ntt_size; ++k) {
      poly_a[k] *= poly_b[k];
    }
    ntt_dit4(poly_a, ldn, -1);

    const mod_t scale = mod_t(ntt_size).inverse();
    for (uint k = 0; k < size_a + size_b - 1; ++k) {
      poly_a[k] *= scale;
    }
  }

  // Transform of length 2^ldn; sign < 0 selects the inverse roots
  // (the result is not scaled).
  static void ntt_dit4(mod_t* f, uint ldn, int sign) {
    revbin_permute(f, 1u << ldn);
    ntt_dit4_core(f, ldn, sign);
  }

private:
  // Returns twice the smallest power of two >= size and stores its log2 in ldn.
  static uint transform_length(uint size, uint& ldn) {
    uint length = 1;
    ldn = 0;
    while (length < size) {
      length <<= 1;
      ++ldn;
    }
    length <<= 1;
    ++ldn;
    return length;
  }

  // (a, b) <- (a + b, a - b)
  static inline void sumdiff(mod_t& a, mod_t& b) {
    const mod_t diff = a - b;
    a += b;
    b = diff;
  }

  // Bit-reversal permutation of A[0..n).
  static void revbin_permute(mod_t* A, uint n) {
    if (n <= 2) {
      return;
    }
    const uint half = n >> 1;
    uint rev = 0;
    for (uint idx = 1; idx < n; ++idx) {
      // Increment the bit-reversed counter: flip bits from the top down
      // until a bit goes from 0 to 1.
      uint bit = half;
      rev ^= bit;
      while (!(rev & bit)) {
        bit >>= 1;
        rev ^= bit;
      }
      if (rev > idx) {
        swap(A[idx], A[rev]);
      }
    }
  }

  static void ntt_dit4_core(mod_t *f, uint ldn, int sign) {
    const uint LX = 2;
    const uint n = 1u << ldn;
    const bool inverse = sign < 0;

    // An odd number of levels needs one radix-2 pass before the radix-4 ones.
    if (ldn & 1) {
      for (uint k = 0; k < n; k += 2) {
        sumdiff(f[k], f[k + 1]);
      }
    }

    mod_t imag = mod_t::root2pow(2);
    if (inverse) {
      imag = imag.inverse();
    }

    for (uint ldm = LX + (ldn & 1); ldm <= ldn; ldm += LX) {
      const uint m = 1u << ldm;
      const uint quarter = m >> LX;

      mod_t step = mod_t::root2pow(ldm);
      if (inverse) {
        step = step.inverse();
      }

      mod_t w1 = mod_t(1);
      mod_t w2 = w1;
      mod_t w3 = w1;
      for (uint j = 0; j < quarter; ++j) {
        for (uint base = j; base < n; base += m) {
          mod_t* const p0 = f + base;
          mod_t* const p1 = p0 + quarter;
          mod_t* const p2 = p1 + quarter;
          mod_t* const p3 = p2 + quarter;

          // All four inputs are read before anything is written back.
          const mod_t a0 = *p0;
          const mod_t a2 = *p1 * w2;
          const mod_t a1 = *p2 * w1;
          const mod_t a3 = *p3 * w3;

          const mod_t even_sum = a0 + a2;
          const mod_t odd_sum = a1 + a3;
          const mod_t even_dif = a0 - a2;
          const mod_t odd_dif = (a1 - a3) * imag;

          *p0 = even_sum + odd_sum;
          *p2 = even_sum - odd_sum;
          *p1 = even_dif + odd_dif;
          *p3 = even_dif - odd_dif;
        }
        w1 *= step;
        w2 = w1 * w1;
        w3 = w1 * w2;
      }
    }
  }
};

const uint64 P1 = 0x3f91300000000001ull;
const uint64 P2 = 0x3f93300000000001ull;
const uint64 P1_INV_MOD_P2 = 0x19D3CB8000001FCAull;

typedef Mod64<P1, 0x1941B388165C78EBull, 44, 0x3f912fffffffffffull, 0x0298CD3E4612D42Aull> mod64_t1;
typedef Mod64<P2, 0x394ba9c52fec1825ull, 44, 0x3f932fffffffffffull, 0x30da2c2b74e27a1full> mod64_t2;

// ------------------------------------------------------------------------------

const uint mod = 1e9 + 7;
const uint64 lim = (0xFFFFFFFFFFFFFFFFull / mod - mod) * mod;

const uint MAX_PC = 300;
const uint NTT_BUFF_SIZE = (1 << 14) + 1;

uint dp[MAX_PC + 1][MAX_PC * 13 + 1];
uint poly_P[MAX_PC * 13 + 1];
uint poly_C[MAX_PC * 12 + 1];
uint poly[MAX_PC * 25 + 1];
uint poly_inv[MAX_PC * 25 + 100];
uint sums[MAX_PC * 25 + 1];
uint counts[MAX_PC * 25 + 1];
uint rem[2 * MAX_PC * 25 + 100];
uint64 temp[2 * MAX_PC * 25 + 100];

mod64_t1 A1[NTT_BUFF_SIZE], B1[NTT_BUFF_SIZE];
mod64_t2 A2[NTT_BUFF_SIZE], B2[NTT_BUFF_SIZE];

// Combines the residues modulo P1 (a1) and modulo P2 (a2) of each of the
// first `size` coefficients and stores the combined value reduced modulo
// `mod` in res.
void _poly_mul_restore(const uint size, mod64_t1* a1, mod64_t2* a2, uint* res) {
  const mod64_t2 p1_inverse = mod64_t2(P1_INV_MOD_P2);
  const uint64 p1_reduced = P1 % mod;

  for (uint k = 0; k < size; ++k) {
    uint64 low = a1[k].get_value();
    const uint64 other = a2[k].get_value();
    if (low != other) {
      uint64 delta = other - low;
      if (int64(delta) < 0) {
        delta += P2;  // P1 < P2
      }
      // delta becomes the multiple of P1 that has to be added to `low`.
      delta = (mod64_t2(delta) * p1_inverse).get_value();
      low = low + (delta % mod) * p1_reduced;
    }
    res[k] = low % mod;
  }
}

// res <- p1 * p1 (coefficients modulo `mod`); returns the result size
// 2 * s1 - 1. Uses the global buffers A1 / A2; res may alias p1.
uint poly_square(const uint* p1, uint s1, uint* res) {
  for (uint k = 0; k < s1; ++k) {
    const uint c = p1[k];
    A1[k] = mod64_t1(c);
    A2[k] = mod64_t2(c);
  }
  NTT<mod64_t1>::auto_convolute(A1, s1);
  NTT<mod64_t2>::auto_convolute(A2, s1);

  const uint out_size = s1 * 2 - 1;
  _poly_mul_restore(out_size, A1, A2, res);
  return out_size;
}

// res <- p1 * p2 (coefficients modulo `mod`); returns the result size
// s1 + s2 - 1. Uses the global buffers A1 / A2 / B1 / B2; the inputs are
// copied first, so res may alias either of them.
uint poly_mul(const uint* p1, uint s1, const uint* p2, uint s2, uint* res) {
  for (uint k = 0; k < s1; ++k) {
    const uint c = p1[k];
    A1[k] = mod64_t1(c);
    A2[k] = mod64_t2(c);
  }
  for (uint k = 0; k < s2; ++k) {
    const uint c = p2[k];
    B1[k] = mod64_t1(c);
    B2[k] = mod64_t2(c);
  }
  NTT<mod64_t1>::convolute(A1, s1, B1, s2);
  NTT<mod64_t2>::convolute(A2, s1, B2, s2);

  const uint out_size = s1 + s2 - 1;
  _poly_mul_restore(out_size, A1, A2, res);
  return out_size;
}

// Schoolbook remainder of p1 (s1 coefficients) by p2 (s2 coefficients),
// eliminating from index 0 upwards; only p2[ofs..s2) take part in the
// elimination. The last s2 - 1 entries are written to ret, and s2 - 1 is
// returned. If s1 < s2 the input is copied unchanged and s1 is returned.
// Uses the global scratch array `temp`.
uint _poly_rem(const uint* p1, uint s1, const uint* p2, uint s2, uint ofs, uint* ret) {
  if (s1 < s2) {
    copy(p1, p1 + s1, ret);
    return s1;
  }

  copy(p1, p1 + s1, temp);
  const uint steps = s1 - s2 + 1;
  for (uint row = 0; row < steps; ++row) {
    const uint64 lead = temp[row] % mod;
    if (lead == 0) {
      continue;
    }
    const uint64 factor = mod - lead;
    for (uint col = ofs; col < s2; ++col) {
      uint64& cell = temp[row + col];
      cell += factor * p2[col];
      // Lazy reduction: subtract `lim` (a multiple of mod) once the
      // accumulator reaches it.
      if (cell >= lim) {
        cell -= lim;
      }
    }
  }

  for (uint k = steps; k < s1; ++k) {
    ret[k - steps] = temp[k] % mod;
  }
  return s2 - 1;
}

uint poly_tmp[2 * MAX_PC * 25 + 100];
uint poly_tmp2[2 * MAX_PC * 25 + 100]; uint poly_q[2 * MAX_PC * 25 + 100]; uint poly_div(const uint* poly1, uint s1, const uint* , uint s2, const uint *inv, uint inv_size, uint *res) { assert(s1 >= s2); uint needed_size = s1 - s2 + 1; assert(inv_size >= needed_size); poly_mul(poly1, needed_size, inv, needed_size, res); return needed_size; } uint poly_rem(const uint* poly1, uint s1, const uint* poly2, uint s2, uint* inv, uint inv_size, uint* res) { if (s1 < s2) { copy(poly1, poly1 + s1, res); return s1; } uint q_size = poly_div(poly1, s1, poly2, s2, inv, inv_size, poly_q); poly_mul(poly_q, q_size, poly2, s2, poly_q); uint ofs = s1 - s2 + 1; for (uint i = 0; i < s2 - 1; ++i) { res[i] = (poly1[ofs + i] + mod - poly_q[ofs + i]) % mod; } return s2 - 1; } uint poly_inverse(const uint *poly, uint size, uint* inv, uint needed_size) { uint degs[100]; uint deg = needed_size - 1; uint deg_pos = 0; while (deg) { degs[deg_pos++] = deg; deg >>= 1; } if (size < needed_size) { copy(poly, poly + size, poly_tmp); fill(poly_tmp + size, poly_tmp + needed_size, 0); } else { copy(poly, poly + needed_size, poly_tmp); } uint inv_size = 1; inv[0] = 1; while (deg_pos) { uint t = degs[--deg_pos]; uint added = t + 1 - inv_size; poly_mul(poly_tmp, t + 1, inv, inv_size, poly_tmp2); poly_mul(poly_tmp2 + inv_size, added, inv, added, poly_tmp2); for (uint i = inv_size; i < inv_size + added; ++i) { inv[i] = (mod - poly_tmp2[i - inv_size] % mod) % mod; } inv_size += added; } return inv_size; } uint poly_power_rem(uint64 e, const uint* divisor, uint div_size, uint* ret) { uint64 mask = 1; while (mask <= e) { mask <<= 1; } mask >>= 1; uint ret_size = 0; ret[ret_size++] = 1; const uint inv_size = poly_inverse(divisor, div_size, poly_inv, div_size); while (mask) { if (e & mask) { ret[ret_size++] = 0; } mask >>= 1; if (mask == 0) { break; } ret_size = poly_square(ret, ret_size, ret); ret_size = poly_rem(ret, ret_size, divisor, div_size, poly_inv, inv_size, ret); } if (ret_size >= div_size) { ret_size = 
poly_rem(ret, ret_size, divisor, div_size, poly_inv, inv_size, ret); } return ret_size; } void conv(uint* poly, uint size) { poly[0] = 1; for (uint i = 1; i < size; ++i) { if (poly[i]) { poly[i] = mod - poly[i]; } } } inline void add_mod(uint& a, uint b) { a += b; if (a >= mod) { a -= mod; } } void init_poly(const uint* dice, const uint T, uint* result) { for (uint i = 1; i <= T; ++i) { fill(dp[i], dp[i] + dice[5] * i + 1, 0); } dp[0][0] = 1; for (uint di = 0; di < 6; ++di) { const uint d = dice[di]; for (uint t = 0; t < T; ++t) { for (uint i = t * dice[0]; i <= t * dice[di]; ++i) { if (dp[t][i]) { add_mod(dp[t + 1][i + d], dp[t][i]); } } } } for (uint i = 0; i <= dice[5] * T; ++i) { result[i] = dp[T][i]; } } void solve() { const uint Ps[] = {2, 3, 5, 7, 11, 13}; const uint Cs[] = {4, 6, 8, 9, 10, 12}; uint64 N; uint P, C; while (~scanf("%llu %u %u", &N, &P, &C)) { init_poly(Ps, P, poly_P); init_poly(Cs, C, poly_C); uint size_p = 13 * P + 1; uint size_c = 12 * C + 1; poly_mul(poly_P, size_p, poly_C, size_c, poly); const uint poly_size = 13 * P + 12 * C + 1; conv(poly, poly_size); uint rem_size = poly_power_rem(N + poly_size - 2, poly, poly_size, rem); uint64 ans = 0; for (uint i = 0; i < rem_size; ++i) { ans += rem[i]; } printf("%llu\n", ans % mod); } } int main() { solve(); return 0; }