
問題 No.215 素数サイコロと合成数サイコロ (3-Hard)
ユーザー yosupotyosupot
提出日時 2017-02-07 00:53:24
言語 C++11
(gcc 11.4.0)
実行時間 -
コード長 15,172 bytes
コンパイル時間 3,018 ms
コンパイル使用メモリ 209,356 KB
実行使用メモリ 816,512 KB
最終ジャッジ日時 2024-06-06 16:30:53
合計ジャッジ時間 7,608 ms
judge1 / judge4


入力 結果 実行時間
testcase_00 MLE -
testcase_01 -- -
main.cpp: In function ‘void fft4(bool, int, R4*, R4*)’:
main.cpp:125:19: warning: ignoring return value of ‘int posix_memalign(void**, size_t, size_t)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  125 |     posix_memalign((void **)&ax, 32, sizeof(R4)*N);
      |     ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
main.cpp:126:19: warning: ignoring return value of ‘int posix_memalign(void**, size_t, size_t)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  126 |     posix_memalign((void **)&ay, 32, sizeof(R4)*N);
      |     ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
main.cpp:127:19: warning: ignoring return value of ‘int posix_memalign(void**, size_t, size_t)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  127 |     posix_memalign((void **)&bx, 32, sizeof(R4)*N);
      |     ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
main.cpp:128:19: warning: ignoring return value of ‘int posix_memalign(void**, size_t, size_t)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  128 |     posix_memalign((void **)&by, 32, sizeof(R4)*N);
      |     ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
main.cpp: In function ‘std::vector<_RealType> multiply(std::vector<_RealType>, std::vector<_RealType>) [with Mint = ModInt<1000000007>]’:
main.cpp:169:23: warning: ignoring return value of ‘int posix_memalign(void**, size_t, size_t)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  169 |         posix_memalign((void **)&ax, 32, sizeof(R4)*N); memset(ax, 0, sizeof(R4)*N);
      |         ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
main.cpp:170:23: warning: ignoring return value of ‘int posix_memalign(void**, size_t, size_t)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  170 |         posix_memalign((void **)&ay, 32, sizeof(R4)*N); memset(ay, 0, sizeof(R4)*N);
      |         ~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


diff #

#pragma GCC target ("avx")

#include <iostream>
#include <iomanip>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cassert>
#include <algorithm>
#include <numeric>
#include <random>
#include <vector>
#include <array>
#include <bitset>
#include <queue>
#include <set>
#include <unordered_set>
#include <map>
#include <unordered_map>
#include <complex>
#include <immintrin.h>
#define ALIGN __attribute__((aligned(32)))

using namespace std;
using uint = unsigned int;
using ll = long long;
using ull = unsigned long long;
template<class T> using V = vector<T>;
template<class T> using VV = V<V<T>>;

// 10^n as a compile-time constant (fits in ll for 0 <= n <= 18).
constexpr ll TEN(int n) { return (n==0) ? 1 : 10*TEN(n-1); }

// Index of the most significant set bit; for x > 0 this is floor(log2(x)).
// x must be non-zero: __builtin_clz(0) / __builtin_clzll(0) are undefined.
int bsr(int x) { assert(x != 0); return 31 - __builtin_clz(x); }
int bsr(ll x) { assert(x != 0); return 63 - __builtin_clzll(x); }

// Index of the least significant set bit. x must be non-zero (ctz(0) is undefined).
int bsf(int x) { assert(x != 0); return __builtin_ctz(x); }
int bsf(ll x) { assert(x != 0); return __builtin_ctzll(x); }

// Binary exponentiation: returns r * x^n using O(log n) multiplications of T.
// n must be non-negative: with GCC's arithmetic right shift a negative n never
// reaches 0 under `n >>= 1`, so the loop would not terminate.
template<class T>
T pow(T x, ll n, T r = 1) {
    assert(n >= 0);
    while (n) {
        if (n & 1) r *= x;
        x *= x;
        n >>= 1;
    }
    return r;
}

template<uint MD>
struct ModInt {
    uint v;
    ModInt() : v{0} {}
    ModInt(ll v) : v{normS(v%MD+MD)} {}
    explicit operator bool() const {return v != 0;}
    static uint normS(const uint &x) {return (x<MD)?x:x-MD;};
    static ModInt make(const uint &x) {ModInt m; m.v = x; return m;}
    static ModInt inv(const ModInt &x) {return pow(ModInt(x), MD-2);} 
    ModInt operator+(const ModInt &r) const {return make(normS(v+r.v));}
    ModInt operator-(const ModInt &r) const {return make(normS(v+MD-r.v));}
    ModInt operator*(const ModInt &r) const {return make((ull)v*r.v%MD);}
    ModInt operator/(const ModInt &r) const {return *this*inv(r);}
    ModInt& operator+=(const ModInt &r) {return *this=*this+r;}
    ModInt& operator-=(const ModInt &r) {return *this=*this-r;}
    ModInt& operator*=(const ModInt &r) {return *this=*this*r;}
    ModInt& operator/=(const ModInt &r) {return *this=*this/r;}
// Decimal text of the stored residue.
template<uint MD>
string to_string(ModInt<MD> m) {
    const uint residue = m.v;
    return std::to_string(residue);
}
// Arithmetic modulo 10^9 + 7.
using Mint = ModInt<TEN(9)+7>;

using R = double;
// pi, computed as 4*atan(1).
const R PI = 4*std::atan(R(1));

// Minimal complex number over R, used by the FFT routines.
struct Pc {
    R x, y;  // real and imaginary parts
    Pc() : x(0), y(0) {}
    Pc(R x, R y) : x(x), y(y) {}
    Pc operator+(const Pc &r) const {return Pc(x+r.x, y+r.y);}
    Pc operator-(const Pc &r) const {return Pc(x-r.x, y-r.y);}
    // (x + iy)(r.x + i r.y)
    Pc operator*(const Pc &r) const {return Pc(x*r.x-y*r.y, x*r.y+y*r.x);}
    // Scaling by a real factor.
    Pc operator*(const R &r) const {return Pc(x*r, y*r);}
    Pc& operator+=(const Pc &r) {return *this=*this+r;}
    Pc& operator-=(const Pc &r) {return *this=*this-r;}
    Pc& operator*=(const Pc &r) {return *this=*this*r;}
    Pc& operator*=(const R &r) {return *this=*this*r;}
    // r * (cos th + i sin th)
    static Pc polar(R r, R th) {return Pc(std::cos(th)*r, std::sin(th)*r);}
};

// In-place complex FFT of c, whose size must be a power of two.
// Implemented out of place with two ping-pong buffers (a, b); there is no
// separate bit-reversal pass. type == false uses twiddles e^{+2*pi*i*k/N},
// type == true uses their conjugates. No 1/N scaling is applied (callers divide).
void fft(bool type, vector<Pc> &c) {
    static vector<Pc> buf[30];  // twiddle tables cached per log2(N)
    int N = int(c.size());
    assert(N >= 1);
    int s = bsr(N);
    assert(1<<s == N);
    if (!buf[s].size()) {
        buf[s] = vector<Pc>(N);
        for (int i = 0; i < N; i++) {
            buf[s][i] = Pc::polar(1, i*2*PI/N);
        }
    }
    vector<Pc> a = c, b(N);
    for (int i = 1; i <= s; i++) {
        int W = 1<<(s-i);  // block width W after this stage
        for (int y = 0; y < N/2; y += W) {
            Pc now = buf[s][y]; if (type) now.y *= -1;
            for (int x = 0; x < W; x++) {
                // y is a multiple of W and x < W, so the ORs below are additions.
                auto l =       a[y<<1 | x];
                auto r = now * a[y<<1 | x | W];
                b[y | x]        = l+r;
                b[y | x | N>>1] = l-r;
            }
        }
        swap(a, b);
    }
    c.swap(a);  // result is in a; avoid a final full copy
}

using R4 = __m256d;
// Four independent length-N complex FFTs computed at once, one per AVX lane.
// cx[i] / cy[i] hold the real / imaginary parts of element i for all four
// lanes; the transformed data overwrites cx / cy. Twiddle convention matches
// fft(): type == false uses e^{+2*pi*i*k/N}, type == true the conjugate.
// No 1/N scaling is applied. N must be a power of two.
void fft4(bool type, int N, R4 cx[], R4 cy[]) {
    static vector<Pc> buf[30];  // twiddle tables cached per log2(N)
    assert(N >= 1);
    int s = bsr(N);
    assert(1<<s == N);
    if (!buf[s].size()) {
        buf[s] = vector<Pc>(N);
        for (int i = 0; i < N; i++) {
            buf[s][i] = Pc::polar(1, i*2*PI/N);
        }
    }
    // Scratch buffers must be 32-byte aligned for __m256d. Abort on allocation
    // failure rather than using an indeterminate pointer.
    auto alloc = [N]() -> R4* {
        R4 *p = nullptr;
        if (posix_memalign((void **)&p, 32, sizeof(R4)*N) != 0) abort();
        return p;
    };
    R4 *ax = alloc(), *ay = alloc(), *bx = alloc(), *by = alloc();
    memcpy(ax, cx, sizeof(R4)*N);
    memcpy(ay, cy, sizeof(R4)*N);
    for (int i = 1; i <= s; i++) {
        int W = 1<<(s-i);  // block width W after this stage
        for (int y = 0; y < N/2; y += W) {
            Pc now = buf[s][y]; if (type) now.y *= -1;
            R4 nowx = _mm256_broadcast_sd(&now.x);
            R4 nowy = _mm256_broadcast_sd(&now.y);
            for (int x = 0; x < W; x++) {
                // Lane-wise butterfly: l = a[2y+x], r = now * a[2y+x+W];
                // b[y+x] = l + r, b[y+x+N/2] = l - r.
                R4 lx = ax[y<<1 | x];
                R4 ly = ay[y<<1 | x];
                R4 bufx = ax[y<<1 | x | W];
                R4 bufy = ay[y<<1 | x | W];
                R4 rx = _mm256_sub_pd(_mm256_mul_pd(bufx, nowx), _mm256_mul_pd(bufy, nowy));
                R4 ry = _mm256_add_pd(_mm256_mul_pd(bufx, nowy), _mm256_mul_pd(bufy, nowx));
                bx[y | x] = _mm256_add_pd(lx, rx);
                by[y | x] = _mm256_add_pd(ly, ry);
                bx[y | x | N>>1] = _mm256_sub_pd(lx, rx);
                by[y | x | N>>1] = _mm256_sub_pd(ly, ry);
            }
        }
        swap(ax, bx);
        swap(ay, by);
    }
    memcpy(cx, ax, sizeof(R4)*N);
    memcpy(cy, ay, sizeof(R4)*N);
    // Release all four scratch buffers (the swaps only permute them). Without
    // this, every call leaked 4*N*sizeof(R4) bytes.
    free(ax); free(ay); free(bx); free(by);
}

// Exact product of two polynomials over Mint (i.e. modulo Mint's modulus).
// Each coefficient residue (assumed < 2^30) is split into three 10-bit limbs,
// and limb k is placed in AVX lane k. One 4-lane FFT per operand then yields
// every limb-pair product. Limb-sums 0..3 use lanes 0..3. Limb-sum 4
// (limb 2 * limb 2) is added to lane 0 multiplied by i. Both time-domain
// results are real, so limb-sum 4 comes back as the imaginary part of lane 0.
// Returns an empty vector if either input is empty.
template<class Mint>
vector<Mint> multiply(const vector<Mint> &x, const vector<Mint> &y) {
    constexpr int SHIFT = 10;
    constexpr int MASK = (1<<SHIFT)-1;
    if (x.empty() || y.empty()) return {};
    int S = int(x.size()+y.size()-1);
    // Smallest power of two >= S; bsr(0) is undefined, so S == 1 is special-cased.
    int N = (S == 1) ? 1 : 2<<bsr(S-1);
    // 32-byte aligned, zero-filled scratch; abort on allocation failure.
    auto alloc_zeroed = [N]() -> R4* {
        R4 *p = nullptr;
        if (posix_memalign((void **)&p, 32, sizeof(R4)*N) != 0) abort();
        memset(p, 0, sizeof(R4)*N);
        return p;
    };

    R4 *ax = alloc_zeroed(), *ay = alloc_zeroed();
    for (int i = 0; i < int(x.size()); i++) {
        // _mm256_set_pd lists lanes from 3 down to 0.
        ax[i] = _mm256_set_pd(0, (x[i].v >> (2*SHIFT)) & MASK, (x[i].v >> SHIFT) & MASK, x[i].v & MASK);
    }
    fft4(false, N, ax, ay);

    R4 *bx = alloc_zeroed(), *by = alloc_zeroed();
    for (int i = 0; i < int(y.size()); i++) {
        bx[i] = _mm256_set_pd(0, (y[i].v >> (2*SHIFT)) & MASK, (y[i].v >> SHIFT) & MASK, y[i].v & MASK);
    }
    fft4(false, N, bx, by);

    R4 *cx = alloc_zeroed(), *cy = alloc_zeroed();
    for (int i = 0; i < N; i++) {
        R nax[4] ALIGN, nay[4] ALIGN, nbx[4] ALIGN, nby[4] ALIGN, ncx[4] ALIGN = {}, ncy[4] ALIGN = {};
        _mm256_store_pd(nax, ax[i]);
        _mm256_store_pd(nay, ay[i]);
        _mm256_store_pd(nbx, bx[i]);
        _mm256_store_pd(nby, by[i]);
        for (int xf = 0; xf < 3; xf++) {
            for (int yf = 0; yf < 3; yf++) {
                int zf = xf+yf;
                if (zf == 4) continue;
                ncx[zf] += nax[xf]*nbx[yf] - nay[xf]*nby[yf];
                ncy[zf] += nax[xf]*nby[yf] + nay[xf]*nbx[yf];
            }
        }
        // zf == 4: add i * (a2 * b2) into lane 0.
        ncy[0] += nax[2]*nbx[2] - nay[2]*nby[2];
        ncx[0] -= nax[2]*nby[2] + nay[2]*nbx[2];
        cx[i] = _mm256_load_pd(ncx);
        cy[i] = _mm256_load_pd(ncy);
    }
    fft4(true, N, cx, cy);

    vector<Mint> z(S);
    const R inv_n = 1.0/N;
    for (int i = 0; i < S; i++) {
        R ncx[4] ALIGN;
        _mm256_store_pd(ncx, cx[i]);
        Mint base = 1;
        for (int fe = 0; fe < 4; fe++) {
            z[i] += Mint(ll(round(ncx[fe]*inv_n))) * base;
            base *= 1<<SHIFT;
        }
        // Limb-sum 4 lives in the imaginary part of lane 0.
        R ncy[4] ALIGN;
        _mm256_store_pd(ncy, cy[i]);
        z[i] += Mint(ll(round(ncy[0]*inv_n))) * base;
    }
    free(ax); free(ay); free(bx); free(by); free(cx); free(cy);
    return z;
}

// Scalar alternative to multiply(); it is not called from the visible code.
// Splits each coefficient residue (assumed < 2^30) into B = 2 limbs of 15 bits,
// transforms each limb polynomial with fft(), and recombines the 2B-1
// limb-sum convolutions. Returns an empty vector if either input is empty.
template<class Mint>
vector<Mint> multiply2(const vector<Mint> &x, const vector<Mint> &y) {
    constexpr int B = 2, SHIFT = 15;
    if (x.empty() || y.empty()) return {};
    int S = int(x.size()+y.size()-1);
    // Smallest power of two >= S; bsr(0) is undefined, so S == 1 is special-cased.
    int N = (S == 1) ? 1 : 2<<bsr(S-1);
    vector<Pc> a[B], b[B];
    for (int fe = 0; fe < B; fe++) {
        a[fe] = vector<Pc>(N);
        b[fe] = vector<Pc>(N);
        for (int i = 0; i < int(x.size()); i++) {
            a[fe][i] = Pc((x[i].v >> (fe*SHIFT)) & ((1<<SHIFT)-1), 0);
        }
        for (int i = 0; i < int(y.size()); i++) {
            b[fe][i] = Pc((y[i].v >> (fe*SHIFT)) & ((1<<SHIFT)-1), 0);
        }
        fft(false, a[fe]);
        fft(false, b[fe]);
    }
    vector<Mint> z(S);
    vector<Pc> c(N);
    Mint base = 1;
    // fe is the limb-sum xf + yf; each contributes with weight 2^(fe*SHIFT).
    for (int fe = 0; fe <= (B-1)*2; fe++) {
        fill_n(begin(c), N, Pc(0, 0));
        for (int xf = max(fe-(B-1), 0); xf <= min(B-1, fe); xf++) {
            int yf = fe-xf;
            for (int i = 0; i < N; i++) {
                c[i] += a[xf][i]*b[yf][i];
            }
        }
        fft(true, c);
        for (int i = 0; i < S; i++) {
            c[i] *= 1.0/N;
            z[i] += Mint(ll(round(c[i].x))) * base;
        }
        base *= 1<<SHIFT;
    }
    return z;
}
template<class D>
struct Poly {
    V<D> v;
    int size() const {return int(v.size());}
    Poly(int N = 0) : v(V<D>(N)) {}
    Poly(const V<D> &v) : v(v) {shrink();}
    Poly& shrink() {while (v.size() && !v.back()) v.pop_back(); return *this;}
    D freq(int p) const { return (p < size()) ? v[p] : D(0); }

    Poly operator+(const Poly &r) const {
        int N = size(), M = r.size();
        V<D> res(max(N, M));
        for (int i = 0; i < max(N, M); i++) res[i] = freq(i)+r.freq(i);
        return Poly(res);
    Poly operator-(const Poly &r) const {
        int N = size(), M = r.size();
        V<D> res(max(N, M));
        for (int i = 0; i < max(N, M); i++) res[i] = freq(i)-r.freq(i);
        return Poly(res);
    Poly operator*(const Poly &r) const {
        int N = size(), M = r.size();
        if (min(N, M) == 0) return Poly();
        assert(N+M-1 >= 0);
        V<D> res = multiply(v, r.v);
        return Poly(res);
    Poly operator*(const D &r) const {
        V<D> res(size());
        for (int i = 0; i < size(); i++) res[i] = v[i]*r;
        return Poly(res);
    Poly& operator+=(const Poly &r) {return *this = *this+r;}
    Poly& operator-=(const Poly &r) {return *this = *this-r;}
    Poly& operator*=(const Poly &r) {return *this = *this*r;}
    Poly& operator*=(const D &r) {return *this = *this*r;}

    Poly operator<<(const int n) const {
        assert(n >= 0);
        V<D> res(size()+n);
        for (int i = 0; i < size(); i++) {
            res[i+n] = v[i];
        return Poly(res);
    Poly operator>>(const int n) const {
        assert(n >= 0);
        if (size() <= n) return Poly();
        V<D> res(size()-n);
        for (int i = n; i < size(); i++) {
            res[i-n] = v[i];
        return Poly(res);

    // x % y
    Poly rem(const Poly &y) const {
        return *this - y * div(y);
    Poly rem_inv(const Poly &y, const Poly &ny, int B) const {
        return *this - y * div_inv(ny, B);
    Poly div(const Poly &y) const {
        int B = max(size(), y.size());
        return div_inv(y.inv(B), B);
    Poly div_inv(const Poly &ny, int B) const {
        return (*this*ny)>>(B-1);
    // this * this.inv() = x^n + r(x) (size())
    Poly strip(int n) const {
        V<D> res = v;
        res.resize(min(n, size()));
        return Poly(res);
    Poly rev(int n = -1) const {
        V<D> res = v;
        if (n != -1) res.resize(n);
        reverse(begin(res), end(res));
        return Poly(res);
    // f * f.inv() = x^B + r(x) (B >= n)
    Poly inv(int n) const {
        int N = size();
        assert(N >= 1);
        assert(n >= N-1);
        Poly c = rev();
        Poly d = Poly(V<D>({D(1)/c.freq(0)}));
        int i;
        for (i = 1; i+N-2 < n; i *= 2) {
            auto u = V<D>({2});
            d = (d * (Poly(V<D>{2})-c*d)).strip(2*i);
        return d.rev(n+1-N);
// Human-readable form such as "1x^0+3x^2". Zero coefficients are skipped and
// "+" is placed only between emitted terms. A polynomial with no non-zero
// coefficient prints as "0".
template<class D>
string to_string(const Poly<D> &p) {
    string s;
    for (int i = 0; i < p.size(); i++) {
        if (!p.v[i]) continue;
        if (!s.empty()) s += "+";
        s += to_string(p.v[i])+"x^"+to_string(i);
    }
    return s.empty() ? "0" : s;
}
// Returns x^n mod `mod`, by left-to-right binary exponentiation on n.
// Every intermediate result is reduced with rem_inv() against one inverse
// precomputed up front. Requires n >= 0 and a non-zero `mod` whose leading
// coefficient is invertible.
template<class D>
Poly<D> nth_mod(ll n, const Poly<D> &mod) {
    assert(n >= 0);
    // B = 2*mod.size()-1 bounds the size of the squared remainder p*p.
    int B = mod.size() * 2 - 1;
    Poly<D> mod_inv = mod.inv(B);
    Poly<D> p(V<D>{D(1)});
    int m = (!n) ? -1 : bsr(n);
    for (int i = m; i >= 0; i--) {
        if (n & (1LL<<i)) {
            // exponent += 1: multiply by x
            p = (p<<1).rem_inv(mod, mod_inv, B);
        }
        if (i) {
            // exponent *= 2: square
            p = (p*p).rem_inv(mod, mod_inv, B);
        }
    }
    return p;
}

// MD: number of sum slots kept in the per-roll distribution tables (sums >= MD
//     are dropped). Presumably chosen to exceed 13*P + 12*C for the largest
//     allowed dice counts. TODO confirm against the problem limits.
// MC: number of dice-count rows in the per-face DP (dice counts must be < MC).
constexpr int MD = 610 * 13;
constexpr int MC = 310;
const std::vector<int> pd = {2, 3, 5, 7, 11, 13};  // faces of a prime die
const std::vector<int> cd = {4, 6, 8, 9, 10, 12};  // faces of a composite die

int main() {
    ll n; int x, y;
    cin >> n >> x >> y;
    V<Mint> co(MD+1); co[0] = 1;
        int c = x;
        Mint buf[MC][MD];
        buf[0][0] = 1;
        for (int d: pd) {
            for (int nc = 1; nc < MC; nc++) {
                for (int nw = d; nw < MD; nw++) {
                    buf[nc][nw] += buf[nc-1][nw-d];
        V<Mint> co2(MD+1);
        for (int i = 0; i < MD; i++) {
            for (int j = 0; j < MD-i; j++) {
                co2[i+j] += co[i]*buf[c][j];
        co = co2;
        int c = y;
        Mint buf[MC][MD];
        buf[0][0] = 1;
        for (int d: cd) {
            for (int nc = 1; nc < MC; nc++) {
                for (int nw = d; nw < MD; nw++) {
                    buf[nc][nw] += buf[nc-1][nw-d];
        V<Mint> co2(MD+1);
        for (int i = 0; i < MD; i++) {
            for (int j = 0; j < MD-i; j++) {
                co2[i+j] += co[i]*buf[c][j];
        co = co2;
    co[0] = -1;
    auto rco = co;
    reverse(begin(rco), end(rco));
    auto pol = nth_mod(n, Poly<Mint>(rco));
    V<Mint> buf(MD); buf[0] = 1;
    V<Mint> sm(MD);
    for (int i = 0; i < MD; i++) {
        //v[i] * f
        for (int j = 0; j < MD; j++) {
            sm[j] += pol.freq(i)*buf[j];
        for (int j = 1; j < MD; j++) {
            buf[j] += buf[0]*co[j];
        for (int j = 0; j < MD-1; j++) {
            buf[j] = buf[j+1];
        buf[MD-1] = 0;
    Mint ans = 0;
    for (int i = 0; i < MD; i++) {
        ans += sm[i];
    cout << ans.v << endl;
    return 0;