結果

問題 No.214 素数サイコロと合成数サイコロ (3-Medium)
ユーザー Min_25Min_25
提出日時 2015-05-23 17:05:18
言語 C++11 (gcc 11.4.0)
結果
AC  
実行時間 74 ms / 3,000 ms
コード長 4,664 bytes
コンパイル時間 398 ms
コンパイル使用メモリ 55,876 KB
実行使用メモリ 4,380 KB
最終ジャッジ日時 2023-09-20 10:53:45
合計ジャッジ時間 1,093 ms
ジャッジサーバーID
(参考情報)
judge12 / judge14
このコードへのチャレンジ(β)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 72 ms 4,376 KB
testcase_01 AC 68 ms 4,380 KB
testcase_02 AC 74 ms 4,376 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cstdio>
#include <cassert>
#include <iostream>

using namespace std;

// Integer shorthands used throughout the solution.
using uint = unsigned;
using uint64 = unsigned long long;
using int64 = long long;


// Scratch DP table filled by init_poly: row t, column s (rows 0..300, sums 0..3900).
uint dp[301][3901];

// Generating polynomials of the prime die, the composite die and their product.
uint poly_P[4000];
uint poly_C[4000];
uint poly[8000];
uint poly_inv[8000];  // currently unused
uint counts[8000];    // currently unused
uint sums[8000];      // currently unused

// rem receives the result of poly_power_rem; temp holds 64-bit lazily reduced
// accumulators shared by the polynomial routines.
uint rem[16000];
uint64 temp[16000];

// Modulus applied to every coefficient.
constexpr uint mod = 1000000007;
// Multiple of mod chosen so that lim + mod * mod still fits below 2^63 (currently unused).
constexpr uint64 lim = (0x7FFFFFFFFFFFFFFFull / mod - mod) * mod;

// Adds b to a in place, folding the sum back below mod once.
// Assumes both operands are already reduced (< mod).
inline void add_mod(uint& a, uint b) {
  const uint sum = a + b;
  a = (sum >= mod) ? sum - mod : sum;
}

// Fills result[0 .. dice[5]*T] so that result[s] is the number (mod `mod`) of ways
// to choose T faces of the six-face die `dice` — order ignored, repetition allowed —
// whose values add up to s. Uses the global dp table as scratch.
// Assumes dice[] is sorted ascending (dice[0] smallest, dice[5] largest) and
// T <= 300 so the dp table is large enough.
void init_poly(const uint* dice, const uint T, uint* result) {
  // Clear the reachable prefix of every row that will be filled. Row 0 is never
  // written below except for dp[0][0].
  for (uint row = 1; row <= T; ++row) {
    uint* const first = dp[row];
    fill(first, first + dice[5] * row + 1, 0);
  }
  dp[0][0] = 1;

  // Faces are added one at a time (outer loop) and counts grow upward (inner loop),
  // so each face may be reused and each multiset is counted exactly once.
  const uint smallest = dice[0];
  for (uint face = 0; face < 6; ++face) {
    const uint value = dice[face];
    for (uint cnt = 0; cnt < T; ++cnt) {
      const uint* const cur = dp[cnt];
      uint* const next = dp[cnt + 1] + value;  // next[s] is dp[cnt + 1][s + value]
      const uint lo = cnt * smallest;
      const uint hi = cnt * value;  // faces beyond `face` are not used yet
      for (uint s = lo; s <= hi; ++s) {
        if (cur[s] != 0) {
          add_mod(next[s], cur[s]);
        }
      }
    }
  }

  const uint top = dice[5] * T;
  for (uint s = 0; s <= top; ++s) {
    result[s] = dp[T][s];
  }
}

// 64-bit snapshot of the polynomial currently used as a multiplier or divisor.
// Callers write a zero right after its last coefficient.
uint64 poly64[8000];
// Two-lane constants read by sub_mod_sse2 as 128-bit SSE memory operands,
// hence the 16-byte alignment.
alignas(16) uint64 mods[2];
alignas(16) uint64 coefs[2];

uint64 naive[8000];  // currently unused

// For k in [0, size): dest[k] -= lo32(src[k]) * lo32(coef), on 64-bit lazily
// reduced accumulators. When the difference has bit 63 set (it went negative),
// mods[] — mod << 32, a multiple of mod with a zero low dword — is added back,
// so the value becomes non-negative again and stays congruent modulo mod.
// Callers pass coef = mod - c to accumulate +c * src[k].
//
// Requirements: src values and coef fit in 32 bits; init_mod() has filled mods;
// coefs and mods are 16-byte aligned (they are used as SSE memory operands).
//
// Each asm statement is self-contained. It declares the xmm registers it
// overwrites and a "memory" clobber, because it reads coefs/mods/src and writes
// dest through pointer operands. No register value is assumed to survive
// between asm statements.
void sub_mod_sse2(uint64* src, uint size, uint64* dest, uint64 coef) {
  coefs[0] = coefs[1] = coef;

  const uint paired = size & ~1u;  // largest even count <= size
  for (uint i = 0; i < paired; i += 2) {
    __asm__ __volatile__(
      "movdqu (%0), %%xmm0\n\t"     // xmm0 = src[i], src[i+1]
      "pxor %%xmm2, %%xmm2\n\t"     // xmm2 = 0
      "movdqu (%1), %%xmm1\n\t"     // xmm1 = dest[i], dest[i+1]
      "pmuludq (%2), %%xmm0\n\t"    // xmm0 = lo32(src) * lo32(coef) per lane
      "psubq %%xmm0, %%xmm1\n\t"    // xmm1 = dest - product
      "pcmpgtd %%xmm1, %%xmm2\n\t"  // xmm2 = all-ones in each negative dword
      "pand (%3), %%xmm2\n\t"       // keep mod only in the high dwords that went negative
      "paddd %%xmm2, %%xmm1\n\t"    // add mod << 32 where needed
      "movdqu %%xmm1, (%1)\n\t"     // store back to dest
      :
      : "r"(src + i), "r"(dest + i), "r"(coefs), "r"(mods)
      : "xmm0", "xmm1", "xmm2", "memory");
  }

  // Odd tail: the same arithmetic in scalar form, so neither src[size] nor
  // dest[size] is ever touched.
  if (size & 1u) {
    uint64 d = dest[paired] - (src[paired] & 0xFFFFFFFFull) * (coef & 0xFFFFFFFFull);
    if (d >> 63) {
      d += mods[0];
    }
    dest[paired] = d;
  }
}

// Squares the polynomial p1 (s1 coefficients, assumed < mod) modulo mod.
// Writes the 2*s1 - 1 coefficients to res and returns that count. An empty input
// yields 0 and writes nothing. res may alias p1: p1 is fully read (and snapshotted
// into poly64) before res is written.
uint poly_square(uint* p1, uint s1, uint* res) {
  if (s1 == 0) {
    return 0;  // 2 * s1 - 1 would wrap to UINT_MAX and overrun temp.
  }
  uint s = 2 * s1 - 1;
  fill(temp, temp + s, 0);
  // Diagonal terms p1[i]^2 land on the even indices.
  for (uint i = 0; i < s1; ++i) {
    temp[2 * i] = uint64(p1[i]) * p1[i];
  }

  // Snapshot of the input, followed by a zero sentinel so that an update which
  // processes elements in pairs past the end contributes nothing.
  copy(p1, p1 + s1, poly64);
  poly64[s1] = 0;
  for (uint i = 0; i < s1; ++i) {
    uint coef = poly64[i];
    if (!coef) {
      continue;
    }
    // Cross terms: temp[i + j] += 2 * p1[i] * p1[j] for j > i. This is expressed
    // as subtracting (mod - 2 * p1[i]) * p1[j], which is what sub_mod_sse2 does.
    coef = (coef << 1) % mod;
    coef = mod - coef;
    sub_mod_sse2(poly64 + i + 1, s1 - i - 1, temp + 2 * i + 1, coef);
  }
  for (uint i = 0; i < s; ++i) {
    res[i] = temp[i] % mod;
  }
  return s;
}

// Multiplies p1 (s1 coefficients) by p2 (s2 coefficients) modulo mod and writes
// the s1 + s2 - 1 product coefficients to res. Coefficients are assumed to be
// < mod. res may alias an input: p2 is snapshotted into poly64, and p1 is only
// read before res is written.
void poly_mul(uint* p1, uint s1, uint* p2, uint s2, uint* res) {
  if (s1 == 0 && s2 == 0) {
    return;  // s1 + s2 - 1 would wrap to UINT_MAX and overrun temp.
  }
  uint s = s1 + s2 - 1;
  fill(temp, temp + s, 0);

  // Snapshot of p2, followed by a zero sentinel so that an update which
  // processes elements in pairs past the end contributes nothing.
  copy(p2, p2 + s2, poly64);
  poly64[s2] = 0;
  for (uint i = 0; i < s1; ++i) {
    uint64 coef = p1[i];
    if (!coef) {
      continue;
    }
    // temp[i + j] += p1[i] * p2[j], expressed as subtracting (mod - p1[i]) * p2[j].
    coef = mod - coef;
    sub_mod_sse2(poly64, s2, temp + i, coef);
  }
  for (uint i = 0; i < s; ++i) {
    res[i] = temp[i] % mod;
  }
}

// Reduces p1 (s1 coefficients) modulo p2 (s2 coefficients) by long division. The
// division eliminates p1's coefficients starting from index 0, and the remainder is
// written to ret. Returns the remainder length: s1 when s1 < s2, otherwise s2 - 1.
//
// Assumes p2[0] == 1 (coefficients are not divided by it) and p2[1 .. ofs-1] == 0
// (those entries are skipped). ret may alias p1: the input is copied into temp
// before ret is written.
uint poly_rem(uint* p1, uint s1, const uint* p2, uint s2, uint ofs, uint* ret) {
  if (s1 < s2) {
    // Already reduced. Callers reduce in place (ret == p1), and std::copy requires
    // that the destination does not start inside the source range, so skip the
    // self-copy.
    if (ret != p1) {
      copy(p1, p1 + s1, ret);
    }
    return s1;
  }
  copy(p1, p1 + s1, temp);
  copy(p2, p2 + s2, poly64);
  poly64[s2] = 0;  // zero sentinel after the divisor's last coefficient

  uint dif = s1 - s2 + 1;  // number of leading coefficients to eliminate
  for (uint i = 0; i < dif; ++i) {
    uint coef = temp[i] % mod;
    if (!coef) {
      continue;
    }
    // temp[i + k] -= coef * p2[k] for k >= ofs. temp[i] itself is never read again.
    sub_mod_sse2(poly64 + ofs, s2 - ofs, temp + i + ofs, coef);
  }
  for (uint i = dif; i < s1; ++i) {
    ret[i - dif] = temp[i] % mod;
  }
  return s2 - 1;
}

// Computes x^e modulo `divisor` by left-to-right binary exponentiation, using the
// same layout as poly_square/poly_rem (appending a trailing 0 multiplies by x).
// Stores the coefficients in ret and returns how many were written. ret must have
// room for the squared intermediate of up to 2 * div_size - 1 coefficients.
uint poly_power_rem(uint64 e, const uint* divisor, uint div_size, uint ofs, uint* ret) {
  // Find the highest set bit of e (mask == 0 when e == 0). Scanning down from
  // bit 63 avoids the upward doubling search, which wraps mask to 0 and never
  // terminates once e >= 2^63.
  uint64 mask = uint64(1) << 63;
  while (mask != 0 && (e & mask) == 0) {
    mask >>= 1;
  }
  ret[0] = 1;  // start from the constant polynomial 1
  uint ret_size = 1;

  while (mask) {
    if (e & mask) {
      ret[ret_size++] = 0;  // multiply by x
    }
    mask >>= 1;
    if (mask == 0) {
      break;  // the lowest bit needs no further squaring
    }
    ret_size = poly_square(ret, ret_size, ret);
    ret_size = poly_rem(ret, ret_size, divisor, div_size, ofs, ret);
  }
  // The final multiply-by-x step may have pushed the degree up to the divisor's.
  if (ret_size >= div_size) {
    ret_size = poly_rem(ret, ret_size, divisor, div_size, ofs, ret);
  }
  return ret_size;
}

// Rewrites poly[0 .. size-1] in place as 1 - (poly[1] + poly[2] + ...) modulo mod:
// index 0 becomes 1, and every other coefficient is negated (zero stays zero).
void conv(uint* poly, uint size) {
  poly[0] = 1;
  uint* const end = poly + size;
  for (uint* it = poly + 1; it < end; ++it) {
    *it = (*it == 0) ? 0u : mod - *it;
  }
}

void init_mod() {
  mods[0] = mods[1] = uint64(mod) << 32;  
}

int main() {
  init_mod();

  // Face values of the prime die and the composite die, in ascending order
  // (init_poly relies on dice[0] being the smallest and dice[5] the largest face).
  const uint Ps[] = {2, 3, 5, 7, 11, 13};
  const uint Cs[] = {4, 6, 8, 9, 10, 12};

  uint64 N;
  uint P, C;

  // Process test cases until input runs out. Requiring all three fields also stops
  // on malformed input, where scanf returns 0 without consuming anything and a
  // `~scanf(...)` test would loop forever on stale values.
  // NOTE(review): assumes 1 <= P, C <= 300, matching the dp table size — confirm
  // against the problem constraints.
  while (scanf("%llu %u %u", &N, &P, &C) == 3) {
    init_poly(Ps, P, poly_P);
    init_poly(Cs, C, poly_C);

    uint size_p = 13 * P + 1;
    uint size_c = 12 * C + 1;
    poly_mul(poly_P, size_p, poly_C, size_c, poly);

    const uint poly_size = 13 * P + 12 * C + 1;
    conv(poly, poly_size);
    // The smallest possible total is 2*P + 4*C, so after conv the divisor
    // entries 1 .. 2*P+4*C-1 are zero and poly_rem may skip them.
    uint rem_size = poly_power_rem(N + poly_size - 2, poly, poly_size, 2 * P + 4 * C, rem);

    uint64 ans = 0;
    for (uint i = 0; i < rem_size; ++i) {
      ans += rem[i];
    }
    printf("%llu\n", ans % mod);
  }
  return 0;
}
0