結果
| 問題 | No.8030 ミラー・ラビン素数判定法のテスト |
| ユーザー | 👑 |
| 提出日時 | 2022-08-31 07:46:54 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 23 ms / 9,973 ms |
| コード長 | 13,259 bytes |
| コンパイル時間 | 389 ms |
| コンパイル使用メモリ | 37,512 KB |
| 最終ジャッジ日時 | 2025-02-07 00:13:46 |
| ジャッジサーバーID (参考情報) | judge2 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 10 |
コンパイルメッセージ
main.cpp: In function ‘int main(int, char**)’:
main.cpp:308:10: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
308 | scanf("%d", &n);
| ~~~~~^~~~~~~~~~
main.cpp:310:14: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
310 | scanf("%llu", &x);
| ~~~~~^~~~~~~~~~~~
ソースコード
#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,bmi2,lzcnt,tune=native")
//#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
//#pragma GCC target ("sse4")
#pragma GCC optimize("O3")
//#pragma GCC optimize ("tree-vectorize")
//#pragma GCC optimize("unroll-loops")
#ifndef NDEBUG
#define NDEBUG
#endif
#include <cassert>
#include <ctime>
#include <cstdio>
#include <cstdbool>
#include <cstdint>
// Montgomery arithmetic modulo an odd 64-bit modulus n, with R == 2**64.
// Arguments named ar/br/tr are residues in Montgomery form (a * R mod n) and must be < n;
// plain residues passed to ar()/add()/sub()/div2() must also be < n.
// Even n is not supported (asserted). n == 1 is accepted: every value is 0 (mod 1).
class U64Mont {
private:
    static uint64_t _ni(uint64_t n) { // n * ni == 1 (mod 2**64)
        // // n is odd number, n = 2*k+1, n >= 1, n < 2**64, k is non-negative integer, k >= 0, k < 2**63
        // ni0 := n; // = 2*k+1 = (1+(2**2)*((k*(k+1))**1))/(2*k+1)
        uint64_t ni = n;
        // ni1 := ni0 * (2 - (n * ni0)); // = (1-(2**4)*((k*(k+1))**2))/(2*k+1)
        // ni2 := ni1 * (2 - (n * ni1)); // = (1-(2**8)*((k*(k+1))**4))/(2*k+1)
        // ni3 := ni2 * (2 - (n * ni2)); // = (1-(2**16)*((k*(k+1))**8))/(2*k+1)
        // ni4 := ni3 * (2 - (n * ni3)); // = (1-(2**32)*((k*(k+1))**16))/(2*k+1)
        // ni5 := ni4 * (2 - (n * ni4)); // = (1-(2**64)*((k*(k+1))**32))/(2*k+1)
        // // (n * ni5) mod 2**64 = ((2*k+1) * ni5) mod 2**64 = 1 mod 2**64
        // Newton iteration: each step doubles the number of correct low bits.
        for (int i = 0; i < 5; ++i) {
            ni = ni * (2 - n * ni);
        }
        assert(n * ni == 1); // n * ni == 1 (mod 2**64)
        return ni;
    }
    static uint64_t _n1(uint64_t n) { // == n - 1
        return n - 1;
    }
    static uint64_t _nh(uint64_t n) { // == (n + 1) / 2, computed without overflowing for n == 2**64 - 1
        return (n >> 1) + 1;
    }
    static uint64_t _r(uint64_t n) { // == 2**64 (mod n)
        return (-n) % n;
    }
    static uint64_t _rn(uint64_t n) { // == -1 * (2**64) (mod n), reduced into [0, n)
        // For n == 1, 2**64 mod n == 0 and "n - 0" would be n itself, which is not a reduced residue.
        const uint64_t rem = _r(n);
        return (rem == 0) ? 0 : (n - rem);
    }
    static uint64_t _r2(uint64_t n) { // == 2**128 (mod n)
        return (uint64_t)((-((__uint128_t)n)) % ((__uint128_t)n));
    }
    static uint32_t _k(uint64_t n) { // == trailing_zeros(n - 1); defined as 0 when n - 1 == 0
        // https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html#Other-Builtins
        // __builtin_ctzll(0) is undefined behaviour, so n == 1 is special-cased.
        const uint64_t m1 = _n1(n);
        return (m1 == 0) ? 0 : (uint32_t)__builtin_ctzll(m1);
    }
    static uint64_t _d(uint64_t n) { // == (n - 1) >> trailing_zeros(n - 1) // n == 2**k * d + 1
        return _n1(n) >> _k(n);
    }
public:
    const uint64_t n; // == n
    const uint64_t ni; // n * ni == 1 (mod 2**64)
    const uint64_t n1; // == n - 1
    const uint64_t nh; // == (n + 1) / 2
    const uint64_t r; // == 2**64 (mod n)   (also: 1 in Montgomery form)
    const uint64_t rn; // == -1 * (2**64) (mod n)   (also: -1 in Montgomery form)
    const uint64_t r2; // == 2**128 (mod n)
    const uint64_t d; // == (n - 1) >> trailing_zeros(n - 1) // n == 2**k * d + 1
    const uint32_t k; // == trailing_zeros(n - 1)
    U64Mont(uint64_t n)
        : n(n), ni(_ni(n)), n1(_n1(n)), nh(_nh(n)), r(_r(n)), rn(_rn(n)), r2(_r2(n)), d(_d(n)), k(_k(n))
    { assert((n & 1) == 1); }
    uint64_t add(uint64_t a, uint64_t b) const {
        // add(a, b) == a + b (mod n)
        assert(a < n);
        assert(b < n);
        unsigned long long t, u;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        bool fa = __builtin_uaddll_overflow(a, b, &t);
        bool fs = __builtin_usubll_overflow(t, n, &u);
        // If a + b wrapped past 2**64 the true sum exceeds n, so the wrapped t - n is correct.
        return (fa ? u : (fs ? t : u));
    }
    uint64_t sub(uint64_t a, uint64_t b) const {
        // sub(a, b) == a - b (mod n)
        assert(a < n);
        assert(b < n);
        unsigned long long t;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        bool f = __builtin_usubll_overflow(a, b, &t);
        return (f ? t + n : t);
    }
    uint64_t div2(uint64_t ar) const {
        // div2(ar) == ar / 2 (mod n); for odd ar this is (ar + n) / 2 == (ar >> 1) + (n + 1) / 2
        assert(ar < n);
        return ((ar >> 1) + ((ar & 1) == 0 ? 0 : nh));
    }
    uint64_t mrmul(uint64_t ar, uint64_t br) const {
        // mrmul(ar, br) == (ar * br) / r (mod n)
        // R == 2**64
        // gcd(N, R) == 1
        // N * ni mod R == 1
        // 0 <= ar < N < R
        // 0 <= br < N < R
        // T := ar * br
        // t := floor(T / R) - floor(((T * ni mod R) * N) / R)
        // if t < 0 then return t + N else return t
        assert(ar < n);
        assert(br < n);
        __uint128_t t = ((__uint128_t)ar) * ((__uint128_t)br);
        unsigned long long w;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        bool f = __builtin_usubll_overflow((unsigned long long)(t >> 64), (unsigned long long)((((__uint128_t)(((uint64_t)t) * ni)) * ((__uint128_t)n)) >> 64), &w);
        return (w + (f ? n : 0));
    }
    uint64_t mr(uint64_t ar) const {
        // mr(ar) == ar / r (mod n)   (converts out of Montgomery form)
        // R == 2**64
        // gcd(N, R) == 1
        // N * ni mod R == 1
        // 0 <= ar < N < R
        // t := floor(ar / R) - floor(((ar * ni mod R) * N) / R)
        // if t < 0 then return t + N else return t
        assert(ar < n);
        uint64_t v = (uint64_t)((((__uint128_t)(ar * ni)) * ((__uint128_t)n)) >> 64);
        return ((v == 0) ? 0 : (n - v));
    }
    uint64_t ar(uint64_t a) const {
        // ar(a) == a * r (mod n)   (converts into Montgomery form)
        assert(a < n);
        return mrmul(a, r2);
    }
    uint64_t pow(uint64_t ar, uint64_t b) const {
        // pow(ar, b) == ((ar / r) ** b) * r (mod n); right-to-left binary exponentiation
        assert(ar < n);
        uint64_t tr = ((b & 1) == 0) ? r : ar;
        for (b >>= 1; b != 0; b >>= 1) {
            ar = mrmul(ar, ar);
            if ((b & 1) != 0) { tr = mrmul(tr, ar); }
        }
        return tr;
    }
};
// 64-bit integer square root, one bit pair per round (internal helper).
// Returns floor(sqrt(iv)) and stores the remainder iv - floor(sqrt(iv))**2 in *remain.
// lz = 2 * (number of rounds to skip) + (0 or 1). Skipping assumes the skipped high bit
// pairs of iv are zero, i.e. lz should not exceed the leading-zero count of iv
// (not checked here). lz == 0 always runs all 32 rounds.
uint64_t isqrt64i(const uint64_t iv, uint64_t* remain, const uint32_t lz) {
    const uint32_t shift = lz & ~1u;            // lz rounded down to an even number
    const uint32_t rounds = 32 - (lz >> 1);     // remaining bit pairs to process
    const uint32_t out_shift = rounds * 2;
    // The trial word keeps the partial root in its upper 64 bits and a single marker bit
    // (bit 62) in its lower half, so it equals (4 * root + 1) at the current scale.
    // Everything is pre-shifted right by `shift`.
    const __uint128_t marker = ((__uint128_t)0x4000000000000000ULL) >> shift;
    const __uint128_t root_one = ((__uint128_t)1 << 64) >> shift;                       // +1 on the root
    const __uint128_t hi_mask = (((__uint128_t)0xfffffffffffffffeULL) << 64) >> shift; // doubled root only
    const __uint128_t lo_mask = ((__uint128_t)0xffffffffffffffffULL) >> shift;         // marker only
    __uint128_t rem = iv;
    __uint128_t trial = marker;
    for (uint32_t round = 0; round < rounds; ++round) {
        const bool fits = rem >= trial;
        if (fits) { rem -= trial; }
        // root := 2 * root + (fits ? 1 : 0), keeping the marker bit in place
        trial = ((trial << 1) & hi_mask) + (trial & lo_mask) + (fits ? root_one : 0);
        rem <<= 2;
    }
    *remain = (uint64_t)(rem >> out_shift);
    return (uint64_t)(trial >> out_shift);
}
// 64-bit integer square root with a fixed loop count (always 32 rounds).
// Returns floor(sqrt(iv)) and stores iv - floor(sqrt(iv))**2 in *remain.
uint64_t isqrt64f(const uint64_t iv, uint64_t* remain) {
    const uint32_t skip_none = 0;  // no leading-zero shortcut
    return isqrt64i(iv, remain, skip_none);
}
// 64-bit integer square root with a fixed loop count, returning only the remainder
// iv - floor(sqrt(iv))**2 (zero exactly when iv is a perfect square).
uint64_t isqrt64f_remain(const uint64_t iv) {
    uint64_t rest;
    (void)isqrt64i(iv, &rest, 0);  // root itself is not needed
    return rest;
}
// 64-bit integer square root whose loop count adapts to iv: leading zero bit pairs of iv
// are skipped. Returns floor(sqrt(iv)) and stores iv - floor(sqrt(iv))**2 in *remain.
uint64_t isqrt64d(const uint64_t iv, uint64_t* remain) {
    // __builtin_clzll(0) is undefined behaviour. For iv == 0 skip all 32 rounds (lz == 64),
    // which makes isqrt64i return root 0 with remainder 0.
    const uint32_t lz = (iv == 0) ? 64u : (uint32_t)__builtin_clzll(iv);
    return isqrt64i(iv, remain, lz);
}
// Jacobi symbol (sa / n).
// Intended for odd n > 0 (presumably; not checked). Returns 1 or -1, or 0 when the
// reduction does not end at n == 1 (i.e. gcd(|sa|, n) > 1).
int32_t jacobi(const int64_t sa, uint64_t n) {
    // |sa| via unsigned negation: -sa would be signed overflow (UB) for sa == INT64_MIN.
    uint64_t a = (sa >= 0) ? (uint64_t)sa : ((uint64_t)0 - (uint64_t)sa);
    // (-1 / n) == -1 iff n == 3 (mod 4)
    int32_t j = (sa < 0 && (n & 3) == 3) ? -1 : 1;
    while (a > 0) {
        // Pull out factors of 2: (2 / n) == -1 iff n == 3 or 5 (mod 8).
        int ba = __builtin_ctzll(a);
        a >>= ba;
        if (((n & 7) == 3 || (n & 7) == 5) && (ba & 1) != 0) { j = -j; }
        // Quadratic reciprocity: the sign flips when both a and n are 3 (mod 4).
        if ((a & n & 3) == 3) { j = -j; }
        uint64_t t = n; n = a; a = t; a %= n;
        // Keep a <= n / 2 using (a / n) == (-1 / n) * ((n - a) / n).
        if (a > (n >> 1)) {
            a = n - a;
            if ((n & 3) == 3) { j = -j; }
        }
    }
    return ((n == 1) ? j : 0);
}
bool primetest_u64_miller_sub(const U64Mont *const mont, uint64_t a) {
if (a >= mont->n) { a %= mont->n; if (a == 0) { return true; } }
uint64_t tr = mont->pow(mont->ar(a), mont->d);
if (tr == mont->r || tr == mont->rn) { return true; }
for (uint32_t j = 1; j < mont->k; ++j) {
tr = mont->mrmul(tr, tr);
if (tr == mont->rn) { return true; }
}
return false;
}
// Miller-Rabin primality test (base 2)
// strong pseudoprimes to base 2 ( https://oeis.org/A001262 ): 2047,3277,4033,4681,8321,15841,29341,42799,49141,52633,...
bool primetest_u64_miller_base2_sub(const U64Mont *const mont) {
uint64_t tr = mont->pow(mont->add(mont->r, mont->r), mont->d);
if (tr == mont->r || tr == mont->rn) { return true; }
for (uint32_t j = 1; j < mont->k; ++j) {
tr = mont->mrmul(tr, tr);
if (tr == mont->rn) { return true; }
}
return false;
}
// Strong Lucas probable-prime test, parameters chosen as in Selfridge's method A:
// D is the first of 5, -7, 9, -11, 13, ... with Jacobi(D / n) == -1; P = 1; Q = (1 - D) / 4.
// strong Lucas pseudoprimes ( https://oeis.org/A217255 ): 5459,5777,10877,16109,18971,22499,24569,25199,40309,58519,...
// Writing n + 1 = dd * 2**s with dd odd, n passes iff U_dd == 0 or V_(dd * 2**r) == 0 (mod n)
// for some 0 <= r < s.
// mont is expected to be built for an odd n >= 3 (primetest_u64_bpsw guarantees this).
bool primetest_u64_lucas_sub(const U64Mont *const mont) {
    uint64_t n = mont->n;
    // 2**64 - 1 is divisible by 3 (4**32 == 1 mod 3), so it is composite. Rejecting it here also
    // keeps n + 1 from wrapping to 0, where __builtin_clzll/ctzll below would be undefined.
    if (n == UINT64_MAX) { return false; }
    int64_t d = 5;
    for (int i = 0; i < 64; ++i) {
        if (jacobi(d, n) == -1) { break; }
        // For n == m**2, (D / n) == (D / m)**2 is never -1. If the search runs long, check for a square.
        if (i == 32 && isqrt64f_remain(n) == 0) { return false; }
        if ((i & 1) == 0) { d = -(d + 2); } else { d = 2 - d; }
        assert(i < 63);
    }
    // Q = (1 - D) / 4 reduced into [0, n), in Montgomery form.
    // Q can never be a multiple of n here: that would give D == 1 (mod n) and Jacobi(D / n) == 1.
    uint64_t qm = mont->ar((d < 0) ? ((((uint64_t)(1 - d)) >> 2) % n) : (n - ((((uint64_t)(d - 1)) >> 2) % n)));
    const uint64_t np1 = n + 1;                                  // no wrap: n < 2**64 - 1
    const uint32_t s = (uint32_t)__builtin_ctzll(np1);           // n + 1 == dd * 2**s
    uint64_t k = np1 << __builtin_clzll(np1);                    // bits of n + 1, MSB at bit 63
    uint64_t um = mont->r;                                       // U_1 = 1
    uint64_t vm = mont->r;                                       // V_1 = P = 1
    uint64_t qn = qm;                                            // Q**1
    uint64_t dm = mont->ar((d < 0) ? (n - (((uint64_t)(-d)) % n)) : (((uint64_t)(d)) % n));
    // Left-to-right ladder over the bits of n + 1 after the leading 1. It stops after the
    // lowest set bit, so it ends with (U_dd, V_dd, Q**dd).
    for (k <<= 1; k > 0; k <<= 1) {
        // Doubling: U_2m = U_m * V_m, V_2m = V_m**2 - 2 * Q**m
        um = mont->mrmul(um, vm);
        vm = mont->sub(mont->mrmul(vm, vm), mont->add(qn, qn));
        qn = mont->mrmul(qn, qn);
        if ((k >> 63) != 0) {
            // Increment (P = 1): U_m+1 = (U_m + V_m) / 2, V_m+1 = (D * U_m + V_m) / 2
            uint64_t uu = mont->add(um, vm);
            uu = mont->div2(uu);
            vm = mont->add(mont->mrmul(dm, um), vm);
            vm = mont->div2(vm);
            um = uu;
            qn = mont->mrmul(qn, qm);
        }
    }
    if (um == 0 || vm == 0) { return true; }                      // U_dd == 0 or V_dd == 0 (r = 0)
    // r = 1 .. s-1: V_(dd * 2**r) == 0. r == s (V_(n+1)) is NOT part of the strong Lucas
    // condition. Checking it could only let composites pass.
    for (uint32_t r = 1; r < s; ++r) {
        vm = mont->sub(mont->mrmul(vm, vm), mont->add(qn, qn));
        if (vm == 0) { return true; }
        qn = mont->mrmul(qn, qn);
    }
    return false;
}
// Baillie–PSW primality test on a prepared Montgomery context:
// a base-2 strong probable-prime test followed by a strong Lucas test.
bool primetest_u64_bpsw_sub(const U64Mont *const mont) {
    if (!primetest_u64_miller_base2_sub(mont)) { return false; }
    return primetest_u64_lucas_sub(mont);
}
// Baillie–PSW primality test for any 64-bit n.
// 0, 1 and even numbers other than 2 are handled directly; odd n >= 3 go through Montgomery form.
bool primetest_u64_bpsw(uint64_t n) {
    if (n < 3 || (n & 1) == 0) { return n == 2; }
    U64Mont mont(n);
    return primetest_u64_bpsw_sub(&mont);
}
// Free-function wrappers over U64Mont: each forwards directly to the member
// function of the same name on *mont.
U64Mont u64mont_new(uint64_t n) {
    return U64Mont{n};
}
uint64_t u64mont_add(const U64Mont *const mont, uint64_t ar, uint64_t br) {
    return (*mont).add(ar, br);
}
uint64_t u64mont_sub(const U64Mont *const mont, uint64_t ar, uint64_t br) {
    return (*mont).sub(ar, br);
}
uint64_t u64mont_div2(const U64Mont *const mont, uint64_t ar) {
    return (*mont).div2(ar);
}
uint64_t u64mont_mrmul(const U64Mont *const mont, uint64_t ar, uint64_t br) {
    return (*mont).mrmul(ar, br);
}
uint64_t u64mont_mr(const U64Mont *const mont, uint64_t a) {
    return (*mont).mr(a);
}
uint64_t u64mont_ar(const U64Mont *const mont, uint64_t a) {
    return (*mont).ar(a);
}
uint64_t u64mont_pow(const U64Mont *const mont, uint64_t ar, uint64_t b) {
    return (*mont).pow(ar, b);
}
// Count trailing / leading zero bits of v. __builtin_ctzll(0) and __builtin_clzll(0) are
// undefined behaviour, so v == 0 returns the full bit width of unsigned long long.
int ctzll(unsigned long long v) { return (v == 0) ? (int)(8 * sizeof v) : __builtin_ctzll(v); }
int clzll(unsigned long long v) { return (v == 0) ? (int)(8 * sizeof v) : __builtin_clzll(v); }
// Seven Miller-Rabin bases. NOTE(review): these look like the well-known set (attributed to
// Jim Sinclair) that is deterministic for all n < 2**64 — confirm against the source before relying on it.
const uint64_t bases[] = {2,325,9375,28178,450775,9780504,1795265022};
// Miller-Rabin test over all of `bases`. bases[0] == 2 is handled by the dedicated
// base-2 routine, so the loop starts at index 1 instead of testing base 2 twice.
bool primetest_u64_miller_base7_sub(const U64Mont *const mont) {
    if (!primetest_u64_miller_base2_sub(mont)) { return false; }
    for (unsigned i = 1; i < sizeof(bases) / sizeof(bases[0]); ++i) {
        if (!primetest_u64_miller_sub(mont, bases[i])) { return false; }
    }
    return true;
}
// Deterministic Miller-Rabin primality test for 64-bit n using the seven-base set.
// 0, 1 and even numbers other than 2 are decided directly.
bool primetest_u64_miller_base7(const uint64_t n) {
    if (n < 3 || (n & 1) == 0) { return n == 2; }
    U64Mont mont(n);
    return primetest_u64_miller_base7_sub(&mont);
}
// Reads a query count and then that many unsigned 64-bit integers from stdin.
// For each one, prints "<x> <1 if prime else 0>" using the Baillie–PSW test.
// Wall-clock and process CPU time are reported on stderr.
int main(int, char**) {
    struct timespec time_monotonic_start, time_process_start, time_monotonic_end, time_process_end;
    clock_gettime(CLOCK_MONOTONIC, &time_monotonic_start);
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time_process_start);
    int n = 0;
    // Every scanf is checked: on a failed read the target would keep an old or indeterminate value.
    if (scanf("%d", &n) != 1) {
        fprintf(stderr, "failed to read the number of queries\n");
        return 1;
    }
    for (int i = 0; i < n; ++i) {
        unsigned long long x = 0;
        if (scanf("%llu", &x) != 1) {
            fprintf(stderr, "failed to read query %d of %d\n", i + 1, n);
            return 1;
        }
        // Alternative tester: primetest_u64_miller_base7((uint64_t)x)
        printf("%llu %d\n", x, primetest_u64_bpsw((uint64_t)x) ? 1 : 0);
    }
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time_process_end);
    clock_gettime(CLOCK_MONOTONIC, &time_monotonic_end);
    double d_sec_monotonic =
        (double)(time_monotonic_end.tv_sec - time_monotonic_start.tv_sec) +
        (double)(time_monotonic_end.tv_nsec - time_monotonic_start.tv_nsec) / (1000 * 1000 * 1000);
    double d_sec_process =
        (double)(time_process_end.tv_sec - time_process_start.tv_sec) +
        (double)(time_process_end.tv_nsec - time_process_start.tv_nsec) / (1000 * 1000 * 1000);
    fprintf(stderr, "time_monotonic:%f, time_process:%f\n", d_sec_monotonic, d_sec_process);
    return 0;
}