結果
問題 | No.3030 ミラー・ラビン素数判定法のテスト |
ユーザー | 👑 Mizar |
提出日時 | 2022-08-29 04:08:29 |
言語 | C++17 (gcc 12.3.0 + boost 1.83.0) |
結果 | AC |
実行時間 | 33 ms / 9,973 ms |
コード長 | 7,345 bytes |
コンパイル時間 | 427 ms |
コンパイル使用メモリ | 36,224 KB |
実行使用メモリ | 5,248 KB |
最終ジャッジ日時 | 2024-11-16 23:59:59 |
合計ジャッジ時間 | 1,071 ms |
ジャッジサーバーID (参考情報) | judge2 / judge1 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 2 ms 5,248 KB |
testcase_01 | AC | 1 ms 5,248 KB |
testcase_02 | AC | 2 ms 5,248 KB |
testcase_03 | AC | 2 ms 5,248 KB |
testcase_04 | AC | 21 ms 5,248 KB |
testcase_05 | AC | 22 ms 5,248 KB |
testcase_06 | AC | 12 ms 5,248 KB |
testcase_07 | AC | 12 ms 5,248 KB |
testcase_08 | AC | 12 ms 5,248 KB |
testcase_09 | AC | 33 ms 5,248 KB |
ソースコード
#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,bmi2,lzcnt,tune=native")
//#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#pragma GCC optimize("O3")
//#pragma GCC optimize("unroll-loops")
// NDEBUG is forced on so that every assert() below is compiled out; the
// asserts only document the preconditions of each routine.
#ifndef NDEBUG
#define NDEBUG
#endif
#include <cassert>
#include <ctime>
#include <cstdio>
#include <cstdbool>
#include <cstdint>

// Montgomery arithmetic modulo an odd 64-bit modulus n, with R == 2**64.
//
// A value x is said to be "in Montgomery form" when it is stored as
// xr == x * R (mod n).  All operands and results of the public methods are
// canonical residues in the range [0, n).
//
// The constructor precomputes every constant that depends only on n,
// including the decomposition n - 1 == 2**k * d used by Miller-Rabin.
// The modulus must be odd.  n == 1 is not usable: n - 1 is then 0 and the
// count-trailing-zeros builtin used for k has an undefined result for 0.
class U64Mont {
private:
    // Returns ni such that n * ni == 1 (mod 2**64), for odd n.
    static uint64_t _ni(uint64_t n) {
        // n * ni == 1 (mod 2**64)
        //
        // n is odd number, n = 2*k+1, n >= 1, n < 2**64, k is non-negative integer, k >= 0, k < 2**63
        // ni0 := n; // = 2*k+1 = (1+(2**2)*((k*(k+1))**1))/(2*k+1)
        uint64_t ni = n;
        // ni1 := ni0 * (2 - (n * ni0)); // = (1-(2**4)*((k*(k+1))**2))/(2*k+1)
        // ni2 := ni1 * (2 - (n * ni1)); // = (1-(2**8)*((k*(k+1))**4))/(2*k+1)
        // ni3 := ni2 * (2 - (n * ni2)); // = (1-(2**16)*((k*(k+1))**8))/(2*k+1)
        // ni4 := ni3 * (2 - (n * ni3)); // = (1-(2**32)*((k*(k+1))**16))/(2*k+1)
        // ni5 := ni4 * (2 - (n * ni4)); // = (1-(2**64)*((k*(k+1))**32))/(2*k+1)
        //
        // (n * ni5) mod 2**64 = ((2*k+1) * ni5) mod 2**64 = 1 mod 2**64
        for (int i = 0; i < 5; ++i) {
            // Each step squares the error term, doubling the number of correct low bits.
            ni = ni * (2 - n * ni);
        }
        assert(n * ni == 1); // n * ni == 1 (mod 2**64)
        return ni;
    }

    // == n - 1
    static uint64_t _n1(uint64_t n) {
        return n - 1;
    }

    // == (n + 1) / 2, computed without overflowing for odd n
    static uint64_t _nh(uint64_t n) {
        return (n >> 1) + 1;
    }

    // == 2**64 (mod n); (-n) is 2**64 - n in uint64_t arithmetic
    static uint64_t _r(uint64_t n) {
        return (-n) % n;
    }

    // == -1 * (2**64) (mod n)
    static uint64_t _rn(uint64_t n) {
        return n - _r(n);
    }

    // == 2**128 (mod n); (-n) is 2**128 - n in __uint128_t arithmetic
    static uint64_t _r2(uint64_t n) {
        return (uint64_t)((-((__uint128_t)n)) % ((__uint128_t)n));
    }

    // == trailing_zeros(n - 1)
    static uint32_t _k(uint64_t n) {
        // https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html#Other-Builtins
        return __builtin_ctzll(_n1(n));
    }

    // == (n - 1) >> trailing_zeros(n - 1), so that n == 2**k * d + 1
    static uint64_t _d(uint64_t n) {
        return _n1(n) >> _k(n);
    }

public:
    const uint64_t n;  // == n
    const uint64_t ni; // n * ni == 1 (mod 2**64)
    const uint64_t n1; // == n - 1
    const uint64_t nh; // == (n + 1) / 2
    const uint64_t r;  // == 2**64 (mod n)
    const uint64_t rn; // == -1 * (2**64) (mod n)
    const uint64_t r2; // == 2**128 (mod n)
    const uint64_t d;  // == (n - 1) >> trailing_zeros(n - 1) // n == 2**k * d + 1
    const uint32_t k;  // == trailing_zeros(n - 1)

    // Precomputes all constants for the odd modulus n.
    U64Mont(uint64_t n)
        : n(n), ni(_ni(n)), n1(_n1(n)), nh(_nh(n)), r(_r(n)), rn(_rn(n)),
          r2(_r2(n)), d(_d(n)), k(_k(n)) {
        assert((n & 1) == 1);
    }

    // add(a, b) == a + b (mod n), for a < n and b < n.
    uint64_t add(uint64_t a, uint64_t b) const {
        assert(a < n);
        assert(b < n);
        unsigned long long sum, reduced;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        const bool carry = __builtin_uaddll_overflow(a, b, &sum);
        const bool borrow = __builtin_usubll_overflow(sum, n, &reduced);
        // If the addition wrapped past 2**64 the true sum is certainly >= n and
        // the wrapped difference is the exact result.  Otherwise the sum needs
        // reducing only when subtracting n did not borrow (i.e. sum >= n).
        return (carry || !borrow) ? reduced : sum;
    }

    // sub(a, b) == a - b (mod n), for a < n and b < n.
    uint64_t sub(uint64_t a, uint64_t b) const {
        assert(a < n);
        assert(b < n);
        unsigned long long t;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        const bool borrow = __builtin_usubll_overflow(a, b, &t);
        // A borrow means a < b; adding n wraps the difference back into [0, n).
        return t + (borrow ? n : 0);
    }

    // div2(ar) == ar / 2 (mod n), for ar < n.
    uint64_t div2(uint64_t ar) const {
        assert(ar < n);
        // For odd ar, (ar + n) / 2 == (ar >> 1) + (n + 1) / 2.
        return (ar >> 1) + ((ar & 1) == 0 ? 0 : nh);
    }

    // mrmul(ar, br) == (ar * br) / r (mod n): Montgomery product of two residues.
    uint64_t mrmul(uint64_t ar, uint64_t br) const {
        // R == 2**64
        // gcd(N, R) == 1
        // N * ni mod R == 1
        // 0 <= ar < N < R
        // 0 <= br < N < R
        // T := ar * br
        // t := floor(T / R) - floor(((T * ni mod R) * N) / R)
        // if t < 0 then return t + N else return t
        assert(ar < n);
        assert(br < n);
        const __uint128_t t = ((__uint128_t)ar) * ((__uint128_t)br);
        const uint64_t q = ((uint64_t)t) * ni; // T * ni mod R
        const uint64_t qn_hi = (uint64_t)((((__uint128_t)q) * ((__uint128_t)n)) >> 64);
        unsigned long long w;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        const bool borrow = __builtin_usubll_overflow(
            (unsigned long long)(t >> 64), (unsigned long long)qn_hi, &w);
        return w + (borrow ? n : 0);
    }

    // mr(ar) == ar / r (mod n): converts out of Montgomery form.
    uint64_t mr(uint64_t ar) const {
        // R == 2**64
        // gcd(N, R) == 1
        // N * ni mod R == 1
        // 0 <= ar < N < R
        // t := floor(ar / R) - floor(((ar * ni mod R) * N) / R)
        // if t < 0 then return t + N else return t
        assert(ar < n);
        // floor(ar / R) is 0 here, so the result is -v (mod n).
        const uint64_t v = (uint64_t)((((__uint128_t)(ar * ni)) * ((__uint128_t)n)) >> 64);
        return v == 0 ? 0 : n - v;
    }

    // ar(a) == a * r (mod n): converts into Montgomery form.
    uint64_t ar(uint64_t a) const {
        assert(a < n);
        // (a * R**2) / R == a * R (mod n)
        return mrmul(a, r2);
    }

    // pow(ar, b) == ((ar / r) ** b) * r (mod n): modular power in Montgomery form.
    uint64_t pow(uint64_t ar, uint64_t b) const {
        assert(ar < n);
        if (b == 0) {
            return r; // 1 in Montgomery form
        }
        // Square away the trailing zero bits of the exponent first, so the
        // accumulator can start from the first set bit without a multiply by 1.
        for (; (b & 1) == 0; b >>= 1) {
            ar = mrmul(ar, ar);
        }
        uint64_t tr = ar;
        for (b >>= 1; b != 0; b >>= 1) {
            ar = mrmul(ar, ar);
            if ((b & 1) != 0) {
                tr = mrmul(tr, ar);
            }
        }
        return tr;
    }
};

// Free-function wrappers around the U64Mont methods.

// Builds the Montgomery context for the odd modulus n.
U64Mont u64mont_new(uint64_t n) {
    return U64Mont(n);
}

// == ar + br (mod n)
uint64_t u64mont_add(U64Mont *mont, uint64_t ar, uint64_t br) {
    return mont->add(ar, br);
}

// == ar - br (mod n)
uint64_t u64mont_sub(U64Mont *mont, uint64_t ar, uint64_t br) {
    return mont->sub(ar, br);
}

// == ar / 2 (mod n)
uint64_t u64mont_div2(U64Mont *mont, uint64_t ar) {
    return mont->div2(ar);
}

// == (ar * br) / r (mod n)
uint64_t u64mont_mrmul(U64Mont *mont, uint64_t ar, uint64_t br) {
    return mont->mrmul(ar, br);
}

// == a / r (mod n)
uint64_t u64mont_mr(U64Mont *mont, uint64_t a) {
    return mont->mr(a);
}

// == a * r (mod n)
uint64_t u64mont_ar(U64Mont *mont, uint64_t a) {
    return mont->ar(a);
}

// == ((ar / r) ** b) * r (mod n)
uint64_t u64mont_pow(U64Mont *mont, uint64_t ar, uint64_t b) {
    return mont->pow(ar, b);
}

// Witness bases for the deterministic test below (see the URL in miller_rabin).
const uint64_t bases[] = {2,325,9375,28178,450775,9780504,1795265022};

// Returns true when n is prime, false otherwise.
bool miller_rabin(uint64_t n) {
    // Deterministic variants of the Miller-Rabin primality test
    // http://miller-rabin.appspot.com/
    if (n == 2) {
        return true;
    }
    if (n < 2 || (n & 1) == 0) {
        return false;
    }
    // From here on n is odd and >= 3, as U64Mont requires.
    U64Mont mont(n);
    for (const auto& base : bases) {
        uint64_t a = base;
        if (a >= n) {
            a %= n;
            if (a == 0) {
                // A base that is a multiple of n carries no information.
                continue;
            }
        }
        // tr == a**d in Montgomery form, where n - 1 == 2**k * d.
        uint64_t tr = mont.pow(mont.ar(a), mont.d);
        if (tr == mont.r) {
            continue; // a**d == 1 (mod n)
        }
        // Look for a**(d * 2**j) == -1 (mod n) with 0 <= j < k by repeated
        // squaring; running out of squarings proves n composite.
        for (uint32_t j = 1; tr != mont.rn; ++j) {
            if (j >= mont.k) {
                return false;
            }
            tr = mont.mrmul(tr, tr);
        }
    }
    return true;
}
// Reads a count followed by that many unsigned 64-bit integers from stdin and
// prints "<value> <1|0>" for each one, where 1 means miller_rabin() reported
// the value as prime.  The CPU time spent is written to stderr.
//
// Returns 0 on success and 1 when the count or one of the values cannot be
// parsed; output produced before the bad token is kept.
int main(int argc, char *argv[]) {
    (void)argc;
    (void)argv;

    struct timespec start_time, end_time;
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &start_time);

    int status = 0;
    int n = 0;
    if (scanf("%d", &n) != 1) {
        // Without a valid count there is nothing to process.
        status = 1;
        n = 0;
    }
    for (int i = 0; i < n; ++i) {
        unsigned long long x = 0;
        if (scanf("%llu", &x) != 1) {
            // Truncated or malformed input: stop instead of testing garbage.
            status = 1;
            break;
        }
        printf("%llu %d\n", x, miller_rabin((uint64_t)x) ? 1 : 0);
    }

    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &end_time);
    double d_sec = (double)(end_time.tv_sec - start_time.tv_sec)
                 + (double)(end_time.tv_nsec - start_time.tv_nsec) / (1000 * 1000 * 1000);
    fprintf(stderr, "time:%f\n", d_sec);
    return status;
}