結果
問題 |
No.8030 ミラー・ラビン素数判定法のテスト
|
ユーザー |
|
提出日時 | 2025-05-19 23:24:28 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 |
AC
|
実行時間 | 70 ms / 9,973 ms |
コード長 | 5,805 bytes |
コンパイル時間 | 1,120 ms |
コンパイル使用メモリ | 103,656 KB |
実行使用メモリ | 6,272 KB |
最終ジャッジ日時 | 2025-05-19 23:24:31 |
合計ジャッジ時間 | 3,269 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 10 |
ソースコード
#include <algorithm> #include <bit> #include <cassert> #include <concepts> #include <cstdint> #include <iostream> #include <limits> #include <type_traits> #include <vector> namespace wasd314 { struct dynamic_modint2 { using U1 = std::uint32_t; using U2 = std::uint64_t; using I2 = std::make_signed_t<U2>; using mint = dynamic_modint2; static constexpr int bits1 = std::numeric_limits<U1>::digits; static constexpr int bits2 = std::numeric_limits<U2>::digits; U2 rx; // N static inline U2 mod; // R^-1 % N static inline U2 R_1; // R^1 % N static inline U2 R1; // R^2 % N static inline U2 R2; // -(N^-1) % R static inline U2 N_; // `(x1 * x2) >> bits2` static U2 multiply_high(U2 x, U2 y) { U1 hx = x >> bits1, lx = x; U1 hy = y >> bits1, ly = y; U2 ans = U2(hx) * hy; ans += (U2(hx) * ly) >> bits1; ans += (U2(lx) * hy) >> bits1; U2 m = U2(hx * ly) + U2(lx * hy) + ((U2(lx) * ly) >> bits1); ans += m >> bits1; return ans; } private: static void set_N_() { U2 n_inv = mod; for (int bit = 2; bit < bits2; bit <<= 1) { n_inv *= 2 - n_inv * mod; } N_ = -n_inv; } static void set_R1() { R1 = -mod % mod; } static void set_R2() { R2 = R1; for (int _ = 0; _ < bits2; ++_) { R2 <<= 1; if (R2 >= mod) R2 -= mod; } } static void set_R_1() { R_1 = 1 + multiply_high(mod, N_); } public: static void set_mod(U2 new_mod) { assert(I2(new_mod) > 0); assert(new_mod & 1); mod = new_mod; set_N_(); set_R1(); set_R2(); set_R_1(); } static U2 safe_mod(I2 x) { x %= I2(mod); if (x < 0) x += I2(mod); return x; } // MR(x) static U2 reduce(const U2 &x) { return multiply_reduce(x, 1); } // MR(x * y) static U2 multiply_reduce(const U2 &x, const U2 &y) { U2 t_ = x * y * N_; U2 t = multiply_high(x, y) + multiply_high(t_, mod) + (x * y != 0); return t < mod ? t : t - mod; } dynamic_modint2(const I2 &x) : rx(multiply_reduce(safe_mod(x), R2)) {} // for literal dynamic_modint2(I2 &&x) : rx(multiply_reduce(safe_mod(x), R2)) {} private: dynamic_modint2(const U2 &x, auto) : rx(x) {} public: static mint raw(const U2 &x) { return mint(x, 0); } U2 val() const { return reduce(rx); } mint pow(U2 e) const { mint ans = raw(R1), b(*this); while (e) { if (e & 1) ans *= b; b *= b; e >>= 1; } return ans; } mint &operator+=(const mint &o) { rx += o.rx; if (rx >= mod) rx -= mod; return *this; } mint &operator-=(const mint &o) { if (__builtin_sub_overflow(rx, o.rx, &rx)) rx += mod; return *this; } mint &operator*=(const mint &o) { rx = multiply_reduce(rx, o.rx); return *this; } mint operator+() const { return mint(*this); } mint operator-() const { return mint::raw(0) -= *this; } }; using mint = dynamic_modint2; bool operator==(const mint &x, const mint &y) { return x.rx == y.rx; } bool operator!=(const mint &x, const mint &y) { return !(x == y); } mint operator+(const mint &x, const mint &y) { return mint(x) += y; } mint operator-(const mint &x, const mint &y) { return mint(x) -= y; } mint operator*(const mint &x, const mint &y) { return mint(x) *= y; } using lint = long long; using u64 = std::uint64_t; bool is_prime(u64 n) { if (n < 2) return false; for (u64 p : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { if (n == p) return true; if (n % p == 0) return false; } if (n < 41 * 41) return true; mint::set_mod(n); const mint one = mint::raw(mint::R1), neg_one = -one; auto test_miller_rabin = [&](const std::vector<lint> &bases) { int e = std::countr_zero(n - 1); u64 o = n >> e; for (lint b : bases) { mint x = mint(b).pow(o); if (x == one) continue; for (int ei = 0; ei < e; ++ei) { mint y = x * x; if (y == one) { if (x == neg_one) break; return false; } x = y; if (ei == e - 1) return false; } } return true; }; if (n < 2047) return test_miller_rabin({2}); if (n < 9080191) return test_miller_rabin({31, 73}); if (n < 4759123141) return test_miller_rabin({2, 7, 61}); if (n < 1122004669633) return test_miller_rabin({2, 13, 23, 1662803}); if (n < 3770579582154547) return test_miller_rabin({2, 880937, 2570940, 610386380, 4130785767}); return test_miller_rabin({2, 325, 9375, 28178, 450775, 9780504, 1795265022}); } } // namespace wasd314 int main() { using namespace std; using namespace wasd314; int q; cin >> q; for (int _ = 0; _ < q; ++_) { lint n; cin >> n; cout << n << ' ' << is_prime(n) << "\n"; } }