結果

問題 No.8030 ミラー・ラビン素数判定法のテスト
ユーザー wasd314
提出日時 2025-05-20 01:24:45
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 53 ms / 9,973 ms
コード長 4,252 bytes
コンパイル時間 1,363 ms
コンパイル使用メモリ 103,508 KB
実行使用メモリ 7,844 KB
最終ジャッジ日時 2025-05-20 01:24:48
合計ジャッジ時間 2,389 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 10
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <algorithm>
#include <bit>
#include <cassert>
#include <concepts>
#include <cstdint>
#include <iostream>
#include <limits>
#include <type_traits>
#include <vector>

namespace wasd314
{
    using i64 = std::int64_t;
    using u64 = std::uint64_t;
    using lint = i64;

    struct dynamic_mod3 {
        using U2 = u64;
        using U4 = __uint128_t;
        using I2 = std::make_signed_t<U2>;

        // N
        U2 mod;
        // R^-1 % N
        U2 R_1;
        // R^1 % N
        U2 R1;
        // R^2 % N
        U2 R2;
        // -(N^-1) % R
        U2 N_;

        static constexpr int bits2 = std::numeric_limits<U2>::digits;

        U2 get_N_() const
        {
            U2 n_inv = mod;
            for (int bits = 2; bits < bits2; bits <<= 1) {
                n_inv *= 2 - n_inv * mod;
            }
            return -n_inv;
        }
        U2 get_R1() const { return -mod % mod; }
        U2 get_R2() const { return -U4(mod) % mod; }
        U2 get_R_1() const { return (1 + U4(mod) * N_) >> bits2; }

        dynamic_mod3(const U2 &mod)
        {
            assert(I2(mod) > 0);
            assert(mod & 1);
            this->mod = mod;
            N_ = get_N_();
            R1 = get_R1();
            R2 = get_R2();
            R_1 = get_R_1();
        }

        U2 safe_mod(I2 x) const
        {
            x %= I2(mod);
            if (x < 0) x += I2(mod);
            return x;
        }
        U2 reduce(const U4 &x) const
        {
            U2 y = (x + U4(U2(x) * N_) * mod) >> bits2;
            return y < mod ? y : y - mod;
        }

        U2 from(const I2 &x) const { return reduce(U4(safe_mod(x)) * R2); }
        U2 from(I2 &&x) const { return reduce(U4(safe_mod(x)) * R2); }

        U2 val(const U2 &rx) const { return reduce(rx); }
        U2 pow(U2 rx, U2 e) const
        {
            U2 ans = R1, b = rx;
            while (e) {
                if (e & 1) ans = reduce(U4(ans) * b);
                b = reduce(U4(b) * b);
                e >>= 1;
            }
            return ans;
        }

        U2 add(const U2 &x, const U2 &y) const
        {
            U2 z = x + y;
            if (z >= mod) z -= mod;
            return z;
        }
        U2 sub(const U2 &x, const U2 &y) const
        {
            U2 z;
            if (__builtin_sub_overflow(x, y, &z)) z += mod;
            return z;
        }
        U2 mul(const U2 &x, const U2 &y) const { return reduce(U4(x) * y); }
        U2 neg(const U2 &x) const { return sub(0, x); }
    };

    bool is_prime(u64 n)
    {
        using U2 = dynamic_mod3::U2;
        if (n < 2) return false;
        for (u64 p : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) {
            if (n == p) return true;
            if (n % p == 0) return false;
        }
        if (n < 41 * 41) return true;

        dynamic_mod3 mont(n);
        const U2 one = mont.R1, neg_one = mont.neg(one);

        auto test_miller_rabin = [&](const std::vector<i64> &bases) {
            int e = std::countr_zero(n - 1);
            u64 o = n >> e;
            for (i64 b : bases) {
                U2 x = mont.pow(mont.from(b), o);
                if (x == one) continue;
                for (int ei = 0; ei < e; ++ei) {
                    U2 y = mont.mul(x, x);
                    if (y == one) {
                        if (x == neg_one) break;
                        return false;
                    }
                    x = y;
                    if (ei == e - 1) return false;
                }
            }
            return true;
        };
        if (n < 2047) return test_miller_rabin({2});
        if (n < 9080191) return test_miller_rabin({31, 73});
        if (n < 4759123141) return test_miller_rabin({2, 7, 61});
        if (n < 1122004669633) return test_miller_rabin({2, 13, 23, 1662803});
        if (n < 3770579582154547) return test_miller_rabin({2, 880937, 2570940, 610386380, 4130785767});
        return test_miller_rabin({2, 325, 9375, 28178, 450775, 9780504, 1795265022});
    }
}  // namespace wasd314


int main()
{
    using namespace std;
    using namespace wasd314;
    int q;
    cin >> q;
    for (int _ = 0; _ < q; ++_) {
        lint n;
        cin >> n;
        cout << n << ' ' << is_prime(n) << "\n";
    }
}
0