結果

問題 No.8030 ミラー・ラビン素数判定法のテスト
ユーザー wasd314
提出日時 2025-05-19 23:24:28
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 70 ms / 9,973 ms
コード長 5,805 bytes
コンパイル時間 1,120 ms
コンパイル使用メモリ 103,656 KB
実行使用メモリ 6,272 KB
最終ジャッジ日時 2025-05-19 23:24:31
合計ジャッジ時間 3,269 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 10
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <algorithm>
#include <bit>
#include <cassert>
#include <concepts>
#include <cstdint>
#include <iostream>
#include <limits>
#include <type_traits>
#include <vector>

namespace wasd314
{
    struct dynamic_modint2 {
        using U1 = std::uint32_t;
        using U2 = std::uint64_t;
        using I2 = std::make_signed_t<U2>;
        using mint = dynamic_modint2;

        static constexpr int bits1 = std::numeric_limits<U1>::digits;
        static constexpr int bits2 = std::numeric_limits<U2>::digits;

        U2 rx;

        // N
        static inline U2 mod;
        // R^-1 % N
        static inline U2 R_1;
        // R^1 % N
        static inline U2 R1;
        // R^2 % N
        static inline U2 R2;
        // -(N^-1) % R
        static inline U2 N_;

        // `(x1 * x2) >> bits2`
        static U2 multiply_high(U2 x, U2 y)
        {
            U1 hx = x >> bits1, lx = x;
            U1 hy = y >> bits1, ly = y;
            U2 ans = U2(hx) * hy;
            ans += (U2(hx) * ly) >> bits1;
            ans += (U2(lx) * hy) >> bits1;
            U2 m = U2(hx * ly) + U2(lx * hy) + ((U2(lx) * ly) >> bits1);
            ans += m >> bits1;
            return ans;
        }

       private:
        static void set_N_()
        {
            U2 n_inv = mod;
            for (int bit = 2; bit < bits2; bit <<= 1) {
                n_inv *= 2 - n_inv * mod;
            }
            N_ = -n_inv;
        }
        static void set_R1() { R1 = -mod % mod; }
        static void set_R2()
        {
            R2 = R1;
            for (int _ = 0; _ < bits2; ++_) {
                R2 <<= 1;
                if (R2 >= mod) R2 -= mod;
            }
        }
        static void set_R_1() { R_1 = 1 + multiply_high(mod, N_); }

       public:
        static void set_mod(U2 new_mod)
        {
            assert(I2(new_mod) > 0);
            assert(new_mod & 1);
            mod = new_mod;
            set_N_();
            set_R1();
            set_R2();
            set_R_1();
        }

        static U2 safe_mod(I2 x)
        {
            x %= I2(mod);
            if (x < 0) x += I2(mod);
            return x;
        }

        // MR(x)
        static U2 reduce(const U2 &x) { return multiply_reduce(x, 1); }
        // MR(x * y)
        static U2 multiply_reduce(const U2 &x, const U2 &y)
        {
            U2 t_ = x * y * N_;
            U2 t = multiply_high(x, y) + multiply_high(t_, mod) + (x * y != 0);
            return t < mod ? t : t - mod;
        }

        dynamic_modint2(const I2 &x) : rx(multiply_reduce(safe_mod(x), R2)) {}
        // for literal
        dynamic_modint2(I2 &&x) : rx(multiply_reduce(safe_mod(x), R2)) {}

       private:
        dynamic_modint2(const U2 &x, auto) : rx(x) {}

       public:
        static mint raw(const U2 &x) { return mint(x, 0); }

        U2 val() const { return reduce(rx); }
        mint pow(U2 e) const
        {
            mint ans = raw(R1), b(*this);
            while (e) {
                if (e & 1) ans *= b;
                b *= b;
                e >>= 1;
            }
            return ans;
        }

        mint &operator+=(const mint &o)
        {
            rx += o.rx;
            if (rx >= mod) rx -= mod;
            return *this;
        }
        mint &operator-=(const mint &o)
        {
            if (__builtin_sub_overflow(rx, o.rx, &rx)) rx += mod;
            return *this;
        }
        mint &operator*=(const mint &o)
        {
            rx = multiply_reduce(rx, o.rx);
            return *this;
        }
        mint operator+() const { return mint(*this); }
        mint operator-() const { return mint::raw(0) -= *this; }
    };
    using mint = dynamic_modint2;
    bool operator==(const mint &x, const mint &y) { return x.rx == y.rx; }
    bool operator!=(const mint &x, const mint &y) { return !(x == y); }
    mint operator+(const mint &x, const mint &y) { return mint(x) += y; }
    mint operator-(const mint &x, const mint &y) { return mint(x) -= y; }
    mint operator*(const mint &x, const mint &y) { return mint(x) *= y; }

    using lint = long long;
    using u64 = std::uint64_t;
    bool is_prime(u64 n)
    {
        if (n < 2) return false;
        for (u64 p : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) {
            if (n == p) return true;
            if (n % p == 0) return false;
        }
        if (n < 41 * 41) return true;
        mint::set_mod(n);

        const mint one = mint::raw(mint::R1), neg_one = -one;

        auto test_miller_rabin = [&](const std::vector<lint> &bases) {
            int e = std::countr_zero(n - 1);
            u64 o = n >> e;
            for (lint b : bases) {
                mint x = mint(b).pow(o);
                if (x == one) continue;
                for (int ei = 0; ei < e; ++ei) {
                    mint y = x * x;
                    if (y == one) {
                        if (x == neg_one) break;
                        return false;
                    }
                    x = y;
                    if (ei == e - 1) return false;
                }
            }
            return true;
        };
        if (n < 2047) return test_miller_rabin({2});
        if (n < 9080191) return test_miller_rabin({31, 73});
        if (n < 4759123141) return test_miller_rabin({2, 7, 61});
        if (n < 1122004669633) return test_miller_rabin({2, 13, 23, 1662803});
        if (n < 3770579582154547) return test_miller_rabin({2, 880937, 2570940, 610386380, 4130785767});
        return test_miller_rabin({2, 325, 9375, 28178, 450775, 9780504, 1795265022});
    }
}  // namespace wasd314


int main()
{
    using namespace std;
    using namespace wasd314;
    int q;
    cin >> q;
    for (int _ = 0; _ < q; ++_) {
        lint n;
        cin >> n;
        cout << n << ' ' << is_prime(n) << "\n";
    }
}
0