結果

問題 No.3030 ミラー・ラビン素数判定法のテスト
ユーザー 👑 MizarMizar
提出日時 2022-09-02 16:24:21
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 21 ms / 9,973 ms
コード長 12,952 bytes
コンパイル時間 672 ms
コンパイル使用メモリ 39,108 KB
実行使用メモリ 5,248 KB
最終ジャッジ日時 2024-11-17 00:04:45
合計ジャッジ時間 1,523 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
5,248 KB
testcase_01 AC 1 ms
5,248 KB
testcase_02 AC 1 ms
5,248 KB
testcase_03 AC 1 ms
5,248 KB
testcase_04 AC 16 ms
5,248 KB
testcase_05 AC 17 ms
5,248 KB
testcase_06 AC 13 ms
5,248 KB
testcase_07 AC 13 ms
5,248 KB
testcase_08 AC 13 ms
5,248 KB
testcase_09 AC 21 ms
5,248 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

//#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,bmi2,lzcnt,tune=native")
//#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
//#pragma GCC target ("sse4")
#pragma GCC optimize("O3")
//#pragma GCC optimize ("tree-vectorize")
//#pragma GCC optimize("unroll-loops")
#ifndef NDEBUG
#define NDEBUG
#endif
#include <cassert>
#include <ctime>
#include <cstdio>
#include <cstdbool>
#include <cstdint>
// Short fixed-width integer / floating-point aliases used throughout this file.
typedef int32_t i32;
typedef int64_t i64;
typedef uint32_t u32;
typedef uint64_t u64;
typedef __uint128_t u128; // GCC/Clang extension: unsigned 128-bit integer
typedef double f64;
// Montgomery arithmetic modulo an odd 64-bit modulus n, with R = 2**64.
// Values named ar / br / tr are in Montgomery form, i.e. (x * R) mod n; use ar()
// to convert into that form and mr() to convert back.
class U64Mont {
private:
    // Inverse of odd n modulo 2**64 (n * ni == 1 mod 2**64) by Newton iteration.
    // For odd n, n * n == 1 (mod 8), so the seed ni = n is correct in its low 3 bits.
    // Each step ni <- ni * (2 - n * ni) doubles the number of correct low bits
    // (3 -> 6 -> 12 -> 24 -> 48 -> 96 >= 64), so 5 steps suffice.
    static u64 _ni(u64 n) noexcept {
        u64 ni = n;
        for (int i = 0; i < 5; ++i) {
            ni = ni * (2 - n * ni);
        }
        assert(n * ni == 1); // n * ni == 1 (mod 2**64)
        return ni;
    }
    static u64 _n1(u64 n) noexcept { // == n - 1
        return n - 1;
    }
    static u64 _nh(u64 n) noexcept { // == (n + 1) / 2 for odd n
        return (n >> 1) + 1;
    }
    static u64 _r(u64 n) noexcept { // == 2**64 (mod n); in u64, -n is 2**64 - n
        return (-n) % n;
    }
    static u64 _rn(u64 n) noexcept { // == -1 * (2**64) (mod n), kept in [0, n)
        // r == 0 only for n == 1; then n - r would be n itself, outside [0, n).
        const u64 rem = _r(n);
        return (rem == 0) ? 0 : (n - rem);
    }
    static u64 _r2(u64 n) noexcept { // == 2**128 (mod n)
        return (u64)((-((u128)n)) % ((u128)n));
    }
    static u32 _k(u64 n) noexcept { // == trailing_zeros(n - 1), or 0 when n == 1
        // https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html#Other-Builtins
        // __builtin_ctzll(0) is undefined, and n - 1 == 0 when n == 1.
        const u64 n1 = _n1(n);
        return (n1 == 0) ? 0 : (u32)__builtin_ctzll(n1);
    }
    static u64 _d(u64 n) noexcept { // == (n - 1) >> trailing_zeros(n - 1) // n == 2**k * d + 1
        return _n1(n) >> _k(n);
    }
public:
    const u64 n; // == n
    const u64 ni; // n * ni == 1 (mod 2**64)
    const u64 n1; // == n - 1
    const u64 nh; // == (n + 1) / 2
    const u64 r; // == 2**64 (mod n), i.e. 1 in Montgomery form
    const u64 rn; // == -1 * (2**64) (mod n), i.e. -1 in Montgomery form
    const u64 r2; // == 2**128 (mod n)
    const u64 d; // == (n - 1) >> trailing_zeros(n - 1) // n == 2**k * d + 1
    const u32 k; // == trailing_zeros(n - 1)
    // n must be odd (checked only by assert, which NDEBUG disables in this file).
    U64Mont(u64 n) noexcept
        : n(n), ni(_ni(n)), n1(_n1(n)), nh(_nh(n)), r(_r(n)), rn(_rn(n)), r2(_r2(n)), d(_d(n)), k(_k(n))
        { assert((n & 1) == 1); }
    u64 add(u64 a, u64 b) const noexcept {
        // add(a, b) == a + b (mod n), for a, b < n
        assert(a < n);
        assert(b < n);
        unsigned long long t, u;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        bool fa = __builtin_uaddll_overflow(a, b, &t);
        bool fs = __builtin_usubll_overflow(t, n, &u);
        // If a + b wrapped past 2**64 it is certainly >= n, so take t - n. Otherwise take t - n
        // only when it did not borrow.
        return (fa ? u : (fs ? t : u));
    }
    u64 sub(u64 a, u64 b) const noexcept {
        // sub(a, b) == a - b (mod n), for a, b < n
        assert(a < n);
        assert(b < n);
        unsigned long long t;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        bool f = __builtin_usubll_overflow(a, b, &t);
        return (f ? t + n : t);
    }
    u64 div2(u64 ar) const noexcept {
        // div2(ar) == ar / 2 (mod n): an odd ar becomes (ar + n) / 2 == (ar >> 1) + (n + 1) / 2
        assert(ar < n);
        return ((ar >> 1) + ((ar & 1) == 0 ? 0 : nh));
    }
    u64 mrmul(u64 ar, u64 br) const noexcept {
        // mrmul(ar, br) == (ar * br) / r (mod n)  (Montgomery reduction of the product)
        // R == 2**64
        // gcd(N, R) == 1
        // N * ni mod R == 1
        // 0 <= ar < N < R
        // 0 <= br < N < R
        // T := ar * br
        // t := floor(T / R) - floor(((T * ni mod R) * N) / R)
        // if t < 0 then return t + N else return t
        // T and (T * ni mod R) * N agree in their low 64 bits, so the difference
        // of the high halves is exactly (T - m * N) / R.
        assert(ar < n);
        assert(br < n);
        u128 t = ((u128)ar) * ((u128)br);
        unsigned long long w;
        // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html#Integer-Overflow-Builtins
        bool f = __builtin_usubll_overflow((unsigned long long)(t >> 64), (unsigned long long)((((u128)(((u64)t) * ni)) * ((u128)n)) >> 64), &w);
        return (w + (f ? n : 0));
    }
    u64 mr(u64 ar) const noexcept {
        // mr(ar) == ar / r (mod n): converts out of Montgomery form
        // R == 2**64
        // gcd(N, R) == 1
        // N * ni mod R == 1
        // 0 <= ar < N < R
        // t := floor(ar / R) - floor(((ar * ni mod R) * N) / R)
        // if t < 0 then return t + N else return t
        // Here floor(ar / R) == 0, so t == -v.
        assert(ar < n);
        u64 v = (u64)((((u128)(ar * ni)) * ((u128)n)) >> 64);
        return ((v == 0) ? 0 : (n - v));
    }
    u64 ar(u64 a) const noexcept {
        // ar(a) == a * r (mod n): converts into Montgomery form
        assert(a < n);
        return mrmul(a, r2);
    }
    u64 pow(u64 ar, u64 b) const noexcept {
        // pow(ar, b) == ((ar / r) ** b) * r (mod n): right-to-left square-and-multiply
        assert(ar < n);
        u64 tr = ((b & 1) == 0) ? r : ar;
        for (b >>= 1; b != 0; b >>= 1) {
            ar = mrmul(ar, ar);
            if ((b & 1) != 0) { tr = mrmul(tr, ar); }
        }
        return tr;
    }
};

// 64-bit integer square root (internal implementation).
// Returns floor(sqrt(iv)) and stores iv - floor(sqrt(iv))**2 into *remain.
// lz: number of leading bits to skip; lz / 2 of the 32 digit-pair rounds are skipped
// (an odd lz rounds down). NOTE(review): presumably lz must not exceed the leading-zero
// count of iv; the callers here pass 0 or __builtin_clzll(iv).
u64 isqrt64i(const u64 iv, u64* remain, const u32 lz) noexcept {
    // `state` packs the partial root (bits from 64 upward, before the `scale` shift)
    // together with a fixed trial bit 2**62 in its low half.
    constexpr u128 TRIAL_BIT = (u128)0x4000000000000000ULL;            // 2**62
    constexpr u128 ROOT_MASK = ((u128)0xfffffffffffffffeULL) << 64;    // bits 65..127
    constexpr u128 ROOT_ONE  = ((u128)0x0000000000000001ULL) << 64;    // 2**64
    constexpr u128 LOW_MASK  = (u128)0xffffffffffffffffULL;            // bits 0..63
    const u32 skip = lz >> 1;
    const u32 rounds = (64 >> 1) - skip;
    const u32 scale = skip << 1;
    const u32 out_shift = rounds << 1;
    const u128 root_mask = ROOT_MASK >> scale;
    const u128 root_one = ROOT_ONE >> scale;
    const u128 low_mask = LOW_MASK >> scale;
    u128 rem = iv;
    u128 state = TRIAL_BIT >> scale;
    for (u32 round = rounds; round != 0; --round) {
        // Double the root part, keep the trial-bit half unchanged.
        const u128 shifted = ((state << 1) & root_mask) + (state & low_mask);
        if (rem >= state) {
            rem -= state;
            state = shifted + root_one; // next root bit is 1
        } else {
            state = shifted;            // next root bit is 0
        }
        // The remainder moves up one digit pair instead of the trial bit moving down.
        rem <<= 2;
    }
    *remain = (u64)(rem >> out_shift);
    return (u64)(state >> out_shift);
}
// 64-bit integer square root using the full, fixed number of rounds (nothing skipped).
// Returns floor(sqrt(iv)) and stores iv - floor(sqrt(iv))**2 into *remain.
u64 isqrt64f(const u64 iv, u64* remain) noexcept {
    const u32 no_skip = 0;
    return isqrt64i(iv, remain, no_skip);
}
// Remainder-only square root: returns iv - floor(sqrt(iv))**2, which is 0 exactly when
// iv is a perfect square. Uses the fixed number of rounds.
u64 isqrt64f_remain(const u64 iv) noexcept {
    u64 rest = 0;
    (void)isqrt64f(iv, &rest);
    return rest;
}
// 64-bit integer square root whose number of rounds depends on the input: leading zero
// digit pairs of iv are skipped. Returns floor(sqrt(iv)) and stores
// iv - floor(sqrt(iv))**2 into *remain.
u64 isqrt64d(const u64 iv, u64* remain) noexcept {
    // __builtin_clzll(0) is undefined behavior; sqrt(0) == 0 with remainder 0.
    if (iv == 0) {
        *remain = 0;
        return 0;
    }
    return isqrt64i(iv, remain, (u32)__builtin_clzll(iv));
}

// Jacobi symbol (sa / n). Returns -1, 0 or 1.
// NOTE(review): n is assumed odd and positive (not checked); for such n the result
// is 0 exactly when gcd(|sa|, n) != 1.
i32 jacobi(const i64 sa, u64 n) noexcept {
    u64 a;
    i32 j;
    if (sa >= 0) {
        a = (u64)sa;
        j = 1;
    } else {
        // Negate in unsigned arithmetic: -sa overflows (UB) for sa == INT64_MIN.
        a = (u64)0 - (u64)sa;
        // (-1 / n) == -1 exactly when n == 3 (mod 4).
        j = ((n & 3) == 3) ? -1 : 1;
    }
    while (a > 0) {
        // Remove factors of two: (2 / n) == -1 when n == 3 or 5 (mod 8).
        const int ba = __builtin_ctzll(a);
        a >>= ba;
        if (((n & 7) == 3 || (n & 7) == 5) && (ba & 1) != 0) { j = -j; }
        // Quadratic reciprocity for odd a, n: flip when both are 3 (mod 4).
        if ((a & n & 3) == 3) { j = -j; }
        const u64 t = n;
        n = a;
        a = t % n;
        // Keep a <= n / 2 using (n - a / n) == (-1 / n) * (a / n).
        if (a > (n >> 1)) {
            a = n - a;
            if ((n & 3) == 3) { j = -j; }
        }
    }
    return ((n == 1) ? j : 0);
}

// One strong-probable-prime (Miller-Rabin) round modulo mont->n for base a.
// Returns true when n passes for this base. A base with a == 0 (mod n) carries no
// information and is treated as passing, whether or not a was already reduced.
bool primetest_u64_miller_sub(const U64Mont *const mont, u64 a) noexcept {
    if (a >= mont->n) { a %= mont->n; }
    // Previously only a reduced base (a >= n) reached this check; a literal 0 < n fell
    // through and was reported as a witness of compositeness.
    if (a == 0) { return true; }
    // x = a**d, with n == 2**k * d + 1 and d odd (Montgomery form).
    u64 tr = mont->pow(mont->ar(a), mont->d);
    if (tr == mont->r || tr == mont->rn) { return true; } // a**d == 1 or -1
    // Square up to k - 1 times, looking for -1.
    for (u32 j = 1; j < mont->k; ++j) {
        tr = mont->mrmul(tr, tr);
        if (tr == mont->rn) { return true; }
    }
    return false;
}

// Miller-Rabin primality test (base 2)
// strong pseudoprimes to base 2 ( https://oeis.org/A001262 ): 2047,3277,4033,4681,8321,15841,29341,42799,49141,52633,...
bool primetest_u64_miller_base2_sub(const U64Mont *const mont) noexcept {
    u64 tr = mont->pow(mont->add(mont->r, mont->r), mont->d);
    if (tr == mont->r || tr == mont->rn) { return true; }
    for (u32 j = 1; j < mont->k; ++j) {
        tr = mont->mrmul(tr, tr);
        if (tr == mont->rn) { return true; }
    }
    return false;
}

// Lucas primality test
// strong Lucas pseudoprimes ( https://oeis.org/A217255 ): 5459,5777,10877,16109,18971,22499,24569,25199,40309,58519,...
bool primetest_u64_lucas_sub(const U64Mont *const mont) noexcept {
    u64 n = mont->n;
    i64 d = 5;
    for (int i = 0; i < 64; ++i) {
        if (jacobi(d, n) == -1) { break; }
        if (i == 32 && isqrt64f_remain(n) == 0) { return false; }
        if ((i & 1) == 0) { d = -(d + 2); } else { d = 2 - d; }
        assert(i < 63);
    }
    u64 qm = mont->ar((d < 0) ? ((((u64)(1 - d)) >> 2) % n) : (n - ((((u64)(d - 1)) >> 2) % n)));
    u64 k = (n + 1) << __builtin_clzll(n + 1); // n = u64::MAX の時の挙動は?
    u64 um = mont->r;
    u64 vm = mont->r;
    u64 qn = qm;
    u64 dm = mont->ar((d < 0) ? (n - (((u64)(-d)) % n)) : (((u64)(d)) % n));
    for (k <<= 1; k > 0; k <<= 1) {
        um = mont->mrmul(um, vm);
        vm = mont->sub(mont->mrmul(vm, vm), mont->add(qn, qn));
        qn = mont->mrmul(qn, qn);
        if ((k >> 63) != 0) {
            u64 uu = mont->add(um, vm);
            uu = mont->div2(uu);
            vm = mont->add(mont->mrmul(dm, um), vm);
            vm = mont->div2(vm);
            um = uu;
            qn = mont->mrmul(qn, qm);
        }
    }
    if (um == 0 || vm == 0) { return true; }
    u64 x = ((n + 1) & (~n));  // n = u64::MAX の時の挙動は?
    for (x >>= 1; x > 0; x >>= 1) {
        um = mont->mrmul(um, vm);
        vm = mont->sub(mont->mrmul(vm, vm), mont->add(qn, qn));
        if (vm == 0) { return true; }
        qn = mont->mrmul(qn, qn);
    }
    return false;
}

// Baillie-PSW primality test on a prepared Montgomery context: a strong
// probable-prime test to base 2, followed (only if it passes) by a strong Lucas test.
bool primetest_u64_bpsw_sub(const U64Mont *const mont) noexcept {
    if (!primetest_u64_miller_base2_sub(mont)) {
        return false;
    }
    return primetest_u64_lucas_sub(mont);
}

// Baillie-PSW primality test for any 64-bit n. Even numbers (other than 2), 0 and 1
// are settled directly, because the Montgomery context requires an odd modulus.
bool primetest_u64_bpsw(u64 n) noexcept {
    if ((n & 1) == 0) {
        return n == 2;
    }
    if (n == 1) {
        return false;
    }
    const U64Mont mont(n);
    return primetest_u64_bpsw_sub(&mont);
}

// Free-function wrappers over U64Mont. Arguments named ar / br are expected in
// Montgomery form (x * 2**64 mod n), matching the member-function contracts.
U64Mont u64mont_new(u64 n) noexcept { return U64Mont{n}; }
u64 u64mont_add(const U64Mont *const mont, u64 ar, u64 br) noexcept {
    return (*mont).add(ar, br);
}
u64 u64mont_sub(const U64Mont *const mont, u64 ar, u64 br) noexcept {
    return (*mont).sub(ar, br);
}
u64 u64mont_div2(const U64Mont *const mont, u64 ar) noexcept {
    return (*mont).div2(ar);
}
u64 u64mont_mrmul(const U64Mont *const mont, u64 ar, u64 br) noexcept {
    return (*mont).mrmul(ar, br);
}
// Converts out of Montgomery form; despite its name, `a` is a Montgomery-form value.
u64 u64mont_mr(const U64Mont *const mont, u64 a) noexcept {
    return (*mont).mr(a);
}
// Converts a plain residue a < n into Montgomery form.
u64 u64mont_ar(const U64Mont *const mont, u64 a) noexcept {
    return (*mont).ar(a);
}
u64 u64mont_pow(const U64Mont *const mont, u64 ar, u64 b) noexcept {
    return (*mont).pow(ar, b);
}
// Trailing-zero count of v. Returns the bit width (64) for v == 0, where
// __builtin_ctzll itself is undefined.
int ctzll(unsigned long long v) noexcept {
    return (v == 0) ? (int)(8 * sizeof v) : __builtin_ctzll(v);
}
// Leading-zero count of v. Returns the bit width (64) for v == 0, where
// __builtin_clzll itself is undefined.
int clzll(unsigned long long v) noexcept {
    return (v == 0) ? (int)(8 * sizeof v) : __builtin_clzll(v);
}

// Seven Miller-Rabin bases (the set attributed to Jim Sinclair), known to make the test
// deterministic for every n < 2**64.
const u64 bases[] = {
    2, 325, 9375, 28178, 450775, 9780504, 1795265022,
};
bool primetest_u64_miller_base7_sub(const U64Mont *const mont) noexcept {
    if (!primetest_u64_miller_base2_sub(mont)) { return false; }
    for (const auto& base : bases) {
        if (!primetest_u64_miller_sub(mont, base)) { return false; }
    }
    return true;
}

bool primetest_u64_miller_base7(const u64 n) noexcept {
    if (n == 2) { return true; }
    if (n < 2 || (n & 1) == 0) { return false; }
    U64Mont mont(n);
    return primetest_u64_miller_base7_sub(&mont);
}

// Reads a count followed by that many unsigned 64-bit integers from stdin and prints
// "<x> <1|0>" for each (1 when the Baillie-PSW test reports x prime).
// Elapsed wall-clock and process CPU time are written to stderr.
int main(int, char**) {
    struct timespec time_monotonic_start, time_process_start, time_monotonic_end, time_process_end;
    clock_gettime(CLOCK_MONOTONIC, &time_monotonic_start);
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time_process_start);
    int n = 0;
    // Stop on malformed input instead of looping over an uninitialized count.
    if (scanf("%d", &n) != 1) { return 1; }
    for (int i = 0; i < n; ++i) {
        unsigned long long x = 0;
        if (scanf("%llu", &x) != 1) { return 1; }
        // primetest_u64_miller_base7((u64)x) is an interchangeable alternative checker.
        printf("%llu %d\n", x, primetest_u64_bpsw((u64)x) ? 1 : 0);
    }
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &time_process_end);
    clock_gettime(CLOCK_MONOTONIC, &time_monotonic_end);
    f64 d_sec_monotonic =
        (f64)(time_monotonic_end.tv_sec - time_monotonic_start.tv_sec) +
        (f64)(time_monotonic_end.tv_nsec - time_monotonic_start.tv_nsec) / (1000 * 1000 * 1000);
    f64 d_sec_process =
        (f64)(time_process_end.tv_sec - time_process_start.tv_sec) +
        (f64)(time_process_end.tv_nsec - time_process_start.tv_nsec) / (1000 * 1000 * 1000);
    fprintf(stderr, "time_monotonic:%f, time_process:%f\n", d_sec_monotonic, d_sec_process);
    return 0;
}
0