結果

| 問題 | No.8030 ミラー・ラビン素数判定法のテスト |
|---|---|
| ユーザー | 👑 |
| 提出日時 | 2022-08-27 20:20:42 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 33 ms / 9,973 ms |
| コード長 | 5,733 bytes |
| コンパイル時間 | 612 ms |
| コンパイル使用メモリ | 30,204 KB |
| 最終ジャッジ日時 | 2025-01-31 06:32:33 |
| ジャッジサーバーID (参考情報) | judge1 / judge5 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 10 |
コンパイルメッセージ
```text
main.cpp: In function ‘int main(int, char**)’:
main.cpp:144:10: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
144 | scanf("%d", &n);
| ~~~~~^~~~~~~~~~
main.cpp:147:14: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
147 | scanf("%llu", &x);
| ~~~~~^~~~~~~~~~~~
```
ソースコード
#include <cstdbool>
#include <cstdint>
#include <cstdio>
// Montgomery modular arithmetic for a fixed odd 64-bit modulus n, with R = 2**64.
//
// A residue a (0 <= a < n) is held in "Montgomery form" ar = a * R (mod n).
// Convert in with ar(a) and back out with mr(ar). add/sub/div2 are linear, so
// they work on plain residues and on Montgomery-form values alike; mrmul and
// pow expect (and return) Montgomery-form values.
//
// Preconditions (not checked here):
//   * n is odd and n >= 3. The inverse computed by _ni() only exists for odd n,
//     and _k() evaluates __builtin_ctzll(n - 1), whose result GCC leaves
//     undefined when its argument is 0 (i.e. n == 1).
//   * Every operand passed to a method is already reduced, i.e. < n.
class U64Mont {
private:
    static uint64_t _ni(uint64_t n) { // n * ni == 1 (mod 2**64)
        // Newton iteration ni <- ni * (2 - n * ni): each step doubles the
        // number of correct low bits of the inverse.
        // // n is odd number, n = 2*k+1, n >= 1, n < 2**64, k is non-negative integer, k >= 0, k < 2**63
        // ni0 := n; // = 2*k+1 = (1+(2**2)*((k*(k+1))**1))/(2*k+1)
        uint64_t ni = n;
        // ni1 := ni0 * (2 - (n * ni0)); // = (1-(2**4)*((k*(k+1))**2))/(2*k+1)
        // ni2 := ni1 * (2 - (n * ni1)); // = (1-(2**8)*((k*(k+1))**4))/(2*k+1)
        // ni3 := ni2 * (2 - (n * ni2)); // = (1-(2**16)*((k*(k+1))**8))/(2*k+1)
        // ni4 := ni3 * (2 - (n * ni3)); // = (1-(2**32)*((k*(k+1))**16))/(2*k+1)
        // ni5 := ni4 * (2 - (n * ni4)); // = (1-(2**64)*((k*(k+1))**32))/(2*k+1)
        // // (n * ni5) mod 2**64 = ((2*k+1) * ni5) mod 2**64 = 1 mod 2**64
        for (int i = 0; i < 5; ++i) {
            ni = ni * (2 - n * ni);
        }
        return ni;
    }
    static uint64_t _n1(uint64_t n) { // == n - 1
        return n - 1;
    }
    static uint64_t _nh(uint64_t n) { // == (n + 1) / 2, for odd n
        return (n >> 1) + 1;
    }
    static uint64_t _r(uint64_t n) { // == 2**64 (mod n)
        // -n wraps to 2**64 - n, which is congruent to 2**64 modulo n.
        return (-n) % n;
    }
    static uint64_t _rn(uint64_t n) { // == -(2**64) (mod n)
        return n - _r(n);
    }
    static uint64_t _r2(uint64_t n) { // == 2**128 (mod n)
        // Same trick as _r, widened to 128 bits: -n wraps to 2**128 - n.
        return (uint64_t)((-((__uint128_t)n)) % ((__uint128_t)n));
    }
    static uint32_t _k(uint64_t n) { // == trailing_zeros(n1)
        // https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html#Other-Builtins
        // (result is undefined for n - 1 == 0, hence the n >= 3 precondition)
        return __builtin_ctzll(n - 1);
    }
    static uint64_t _d(uint64_t n) { // == n1 >> k // n == 2**k * d + 1
        return (n - 1) >> _k(n);
    }
public:
    const uint64_t n;  // == n
    const uint64_t ni; // n * ni == 1 (mod 2**64)
    const uint64_t n1; // == n - 1
    const uint64_t nh; // == (n + 1) / 2
    const uint64_t r;  // == 2**64 (mod n); Montgomery form of 1
    const uint64_t rn; // == -(2**64) (mod n); Montgomery form of n - 1 (i.e. -1)
    const uint64_t r2; // == 2**128 (mod n); used by ar() to enter Montgomery form
    const uint64_t d;  // == n1 >> k // n == 2**k * d + 1
    const uint32_t k;  // == trailing_zeros(n1)
    U64Mont(uint64_t n)
        : n(n), ni(_ni(n)), n1(_n1(n)), nh(_nh(n)), r(_r(n)), rn(_rn(n)), r2(_r2(n)), d(_d(n)), k(_k(n)) {}
    uint64_t add(uint64_t a, uint64_t b) const {
        // add(a, b) == a + b (mod n), for 0 <= a, b < n.
        unsigned long long t, u;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        // t = a + b (mod 2**64); carry is set when the true sum is >= 2**64 (> n).
        bool carry = __builtin_uaddll_overflow(a, b, &t);
        // u = t - n (mod 2**64); borrow is set when t < n.
        bool borrow = __builtin_usubll_overflow(t, n, &u);
        // The true sum a + b must be reduced exactly when it is >= n: either the
        // addition carried out of 64 bits (then t < n and u == a + b - n after
        // wrap-around), or it did not carry and t >= n (no borrow).
        return (carry || !borrow) ? u : t;
    }
    uint64_t sub(uint64_t a, uint64_t b) const {
        // sub(a, b) == a - b (mod n), for 0 <= a, b < n.
        unsigned long long t;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        bool f = __builtin_usubll_overflow(a, b, &t);
        // On borrow, t == a - b + 2**64; adding n (mod 2**64) yields a - b + n.
        return t + (f ? n : 0);
    }
    uint64_t div2(uint64_t ar) const {
        // div2(ar) == ar / 2 (mod n), for 0 <= ar < n.
        if ((ar & 1) == 0) {
            return (ar >> 1);
        } else {
            // Odd ar: (ar + n) / 2 == (ar >> 1) + (n + 1) / 2, computed without overflow.
            return (ar >> 1) + nh;
        }
    }
    uint64_t mrmul(uint64_t ar, uint64_t br) const {
        // mrmul(ar, br) == (ar * br) / r (mod n), result in [0, n)
        // R == 2**64
        // gcd(N, R) == 1
        // N * ni mod R == 1
        // 0 <= ar < N < R
        // 0 <= br < N < R
        // T := ar * br
        // t := floor(T / R) - floor(((T * ni mod R) * N) / R)
        // if t < 0 then return t + N else return t
        __uint128_t t = ((__uint128_t)ar) * ((__uint128_t)br);
        uint64_t u = (uint64_t)(t >> 64);
        uint64_t v = (uint64_t)((((__uint128_t)(((uint64_t)t) * ni)) * ((__uint128_t)n)) >> 64);
        unsigned long long w;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        bool f = __builtin_usubll_overflow(u, v, &w);
        return w + (f ? n : 0);
    }
    uint64_t mr(uint64_t ar) const {
        // mr(ar) == ar / r (mod n): leaves Montgomery form.
        // R == 2**64
        // gcd(N, R) == 1
        // N * ni mod R == 1
        // 0 <= ar < N < R
        // t := floor(ar / R) - floor(((ar * ni mod R) * N) / R)
        // if t < 0 then return t + N else return t
        uint64_t v = (uint64_t)((((__uint128_t)(ar * ni)) * ((__uint128_t)n)) >> 64);
        return v == 0 ? 0 : n - v;
    }
    uint64_t ar(uint64_t a) const {
        // ar(a) == a * r (mod n): enters Montgomery form, for 0 <= a < n.
        return mrmul(a, r2);
    }
    uint64_t pow(uint64_t ar, uint64_t b) const {
        // pow(ar, b) == ((ar / r) ** b) * r (mod n); pow(ar, 0) == r (Montgomery 1).
        // Right-to-left binary exponentiation, all in Montgomery form.
        uint64_t t = ((b & 1) == 0) ? r : ar;
        b >>= 1;
        while (b != 0) {
            ar = mrmul(ar, ar);
            if ((b & 1) != 0) { t = mrmul(t, ar); }
            b >>= 1;
        }
        return t;
    }
};
// Witness bases for the deterministic 64-bit Miller-Rabin test, taken from the
// 7-base set listed at http://miller-rabin.appspot.com/.
const uint64_t bases[] = {
    2,
    325,
    9375,
    28178,
    450775,
    9780504,
    1795265022,
};
bool miller_rabin(uint64_t n) {
// Deterministic variants of the Miller-Rabin primality test
// http://miller-rabin.appspot.com/
if (n == 2) { return true; }
if (n < 2 || (n & 1) == 0) { return false; }
U64Mont mont(n);
for (int i = 0; i < 7; ++i) {
uint64_t a = bases[i];
if (a >= n) { a %= n; if (a == 0) { continue; } }
uint64_t ar = mont.ar(a);
uint64_t tr = mont.pow(ar, mont.d);
if (tr == mont.r || tr == mont.rn) { continue; }
for (int j = 1; j < mont.k; ++j) { tr = mont.mrmul(tr, tr); if (tr == mont.rn) { goto cont; } }
return false;
cont: continue;
}
return true;
}
// Reads a count n followed by n unsigned 64-bit integers from stdin and, for
// each integer x, prints "x 1" if x is prime or "x 0" otherwise.
// Returns 1 if the input is malformed or ends early, 0 on success.
int main(int argc, char *argv[]) {
    int n;
    // Check every scanf result: on failure the target variable would otherwise
    // be read uninitialized.
    if (scanf("%d", &n) != 1) { return 1; }
    for (int i = 0; i < n; ++i) {
        unsigned long long x;
        if (scanf("%llu", &x) != 1) { return 1; }
        printf("%llu %d\n", x, miller_rabin((uint64_t)x) ? 1 : 0);
    }
    return 0;
}