結果

問題 No.3030 ミラー・ラビン素数判定法のテスト
ユーザー 👑 MizarMizar
提出日時 2022-08-27 18:13:11
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 29 ms / 9,973 ms
コード長 6,135 bytes
コンパイル時間 1,218 ms
コンパイル使用メモリ 140,788 KB
実行使用メモリ 5,376 KB
最終ジャッジ日時 2024-04-28 09:59:33
合計ジャッジ時間 1,830 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 2 ms 5,248 KB
testcase_01 AC 1 ms 5,376 KB
testcase_02 AC 2 ms 5,376 KB
testcase_03 AC 2 ms 5,376 KB
testcase_04 AC 18 ms 5,376 KB
testcase_05 AC 17 ms 5,376 KB
testcase_06 AC 10 ms 5,376 KB
testcase_07 AC 10 ms 5,376 KB
testcase_08 AC 11 ms 5,376 KB
testcase_09 AC 29 ms 5,376 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#pragma GCC optimize("O3")
//#pragma GCC optimize("unroll-loops")
#ifndef NDEBUG
#define NDEBUG
#endif
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <ctime>
#ifdef _MSC_VER
#include <intrin.h>
#else
#include <x86intrin.h>
#endif

// Montgomery arithmetic modulo an odd 64-bit modulus n, with R == 2**64.
// A residue a is held in Montgomery form as a * R (mod n): ar() converts into that
// form, mr() converts back out, and mrmul()/pow() take and return Montgomery form.
// add()/sub()/div2() are plain modular operations and work on either representation.
// d and k (n - 1 == 2**k * d, d odd) are precomputed for the Miller-Rabin test.
class U64Mont {
public:
    const uint64_t n; // == n
    const uint64_t ni; // n * ni == 1 (mod 2**64)
    const uint64_t n1; // == n - 1
    const uint64_t nh; // == (n + 1) / 2
    const uint64_t r; // == 2**64 (mod n), i.e. 1 in Montgomery form
    const uint64_t n1r; // == -(2**64) (mod n), i.e. -1 in Montgomery form
    const uint64_t r2; // == 2**128 (mod n)
    const uint64_t d; // == n1 >> k // n == 2**k * d + 1
    const uint32_t k; // == trailing_zeros(n1)
    U64Mont(uint64_t n, uint64_t ni, uint64_t n1, uint64_t nh, uint64_t r, uint64_t n1r, uint64_t r2, uint64_t d, uint32_t k)
        : n(n), ni(ni), n1(n1), nh(nh), r(r), n1r(n1r), r2(r2), d(d), k(k) {}

    // Builds the Montgomery context for modulus n. Precondition: n is odd.
    static U64Mont build(uint64_t n) {
        assert((n & 1) == 1);
        // Newton iteration for n^-1 mod 2**64. For odd n, n * n == 1 (mod 8), so the
        // seed ni = n is correct to 3 bits; each step doubles the number of correct
        // bits (3 -> 6 -> 12 -> 24 -> 48 -> 96), so five steps cover all 64 bits.
        uint64_t ni = n;
        for (int i = 0; i < 5; ++i) {
            ni = ni * (2 - n * ni);
        }
        assert(n * ni == 1); // n * ni == 1 (mod 2**64)
        uint64_t n1 = n - 1;
        uint64_t nh = (n >> 1) + 1; // (n + 1) / 2 without overflow when n == 2**64 - 1
        uint64_t r = (-n) % n; // (2**64 - n) mod n == 2**64 mod n
        // -R mod n; for n == 1 everything is 0, so keep the value inside [0, n).
        uint64_t n1r = (r == 0) ? 0 : n - r;
        uint64_t r2 = (uint64_t)((-((__uint128_t)n)) % ((__uint128_t)n)); // (2**128 - n) mod n
        // __builtin_ctzll(0) is undefined; n1 == 0 happens only for n == 1.
        uint32_t k = (n1 == 0) ? 0 : (uint32_t)__builtin_ctzll(n1);
        uint64_t d = n1 >> k;
        return U64Mont(n, ni, n1, nh, r, n1r, r2, d, k);
    }

    // Returns (a + b) mod n. Requires a < n and b < n.
    uint64_t add(uint64_t a, uint64_t b) const {
        assert(a < n);
        assert(b < n);
        uint64_t t = a + b;   // may wrap around 2**64
        bool carry = t < a;   // the true sum is >= 2**64, hence > n
        // The true sum is < 2n, so one subtraction of n is enough. When carry is set,
        // t - n wraps back to exactly a + b - n.
        return (carry || t >= n) ? t - n : t;
    }

    // Returns (a - b) mod n. Requires a < n and b < n.
    uint64_t sub(uint64_t a, uint64_t b) const {
        assert(a < n);
        assert(b < n);
        uint64_t t = a - b;
        // On borrow, t == a - b + 2**64 and adding n wraps back into [0, n).
        return (a < b) ? t + n : t;
    }

    // Returns ar / 2 (mod n), i.e. ar * 2^-1 mod n. Requires ar < n.
    uint64_t div2(uint64_t ar) const {
        assert(ar < n);
        if ((ar & 1) == 0) {
            return ar >> 1;
        }
        // Odd ar: (ar + n) / 2 == (ar - 1) / 2 + (n + 1) / 2, computed without overflow.
        return (ar >> 1) + nh;
    }

    // Montgomery reduction of a product: returns ar * br / R (mod n).
    // Requires ar < n and br < n.
    uint64_t mrmul(uint64_t ar, uint64_t br) const {
        assert(ar < n);
        assert(br < n);
        __uint128_t t = (__uint128_t)ar * br;
        uint64_t u = (uint64_t)(t >> 64);
        uint64_t m = (uint64_t)t * ni; // m == T * n^-1 (mod R)
        uint64_t v = (uint64_t)(((__uint128_t)m * n) >> 64);
        // The low 64 bits of T and m * n are equal, so (T - m * n) / R == u - v
        // exactly; that value lies in (-n, n), so fix a negative result by adding n.
        return (u < v) ? u - v + n : u - v;
    }

    // Leaves Montgomery form: returns ar / R (mod n). Requires ar < n.
    uint64_t mr(uint64_t ar) const {
        assert(ar < n);
        uint64_t v = (uint64_t)(((__uint128_t)(ar * ni) * n) >> 64);
        // Same reduction as mrmul with a zero high word: result is (0 - v) mod n.
        return v == 0 ? 0 : n - v;
    }

    // Enters Montgomery form: returns a * R (mod n). Requires a < n.
    uint64_t ar(uint64_t a) const {
        assert(a < n);
        return mrmul(a, r2);
    }

    // Square-and-multiply in Montgomery form: returns ((ar / R) ** b) * R (mod n).
    // pow(x, 0) returns r (1 in Montgomery form). Requires ar < n.
    uint64_t pow(uint64_t ar, uint64_t b) const {
        assert(ar < n);
        uint64_t t = ((b & 1) == 0) ? r : ar;
        b >>= 1;
        while (b != 0) {
            ar = mrmul(ar, ar);
            if ((b & 1) != 0) { t = mrmul(t, ar); }
            b >>= 1;
        }
        return t;
    }
};

// Seven-base witness set listed at http://miller-rabin.appspot.com/ for 64-bit inputs.
const uint64_t bases[] = {2,325,9375,28178,450775,9780504,1795265022};

// Deterministic Miller-Rabin primality test for 64-bit integers.
// Returns 1 when n is prime and 0 otherwise.
int miller_rabin(uint64_t n) {
    // Deterministic variants of the Miller-Rabin primality test
    // http://miller-rabin.appspot.com/
    if (n < 2) { return 0; }
    if ((n & 1) == 0) { return (n == 2) ? 1 : 0; }
    U64Mont mont = U64Mont::build(n);
    for (uint64_t base : bases) {
        assert(base > 0);
        uint64_t a = base;
        if (a >= n) {
            a %= n;
            if (a == 0) { continue; } // base is a multiple of n: it proves nothing
        }
        // x = a**d in Montgomery form; a**d == +-1 means this base passes.
        uint64_t x = mont.pow(mont.ar(a), mont.d);
        if (x == mont.r || x == mont.n1r) { continue; }
        // Square up to k - 1 times looking for -1; if it never shows up, a is a
        // witness that n is composite.
        bool composite = true;
        for (uint32_t step = 1; composite && step < mont.k; ++step) {
            x = mont.mrmul(x, x);
            composite = (x != mont.n1r);
        }
        if (composite) { return 0; }
    }
    return 1;
}

// Reads a count followed by that many unsigned 64-bit integers from stdin and, for
// each value x, prints "x 1" when miller_rabin(x) reports prime and "x 0" otherwise.
// The process CPU time spent is reported on stderr. Returns 1 on malformed input.
int main(int argc, char *argv[]) {
    (void)argc;
    (void)argv;
    struct timespec start_time, end_time;
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &start_time);
    int n = 0;
    if (scanf("%d", &n) != 1) { return 1; }
    for (int i = 0; i < n; ++i) {
        // Read into exactly the type %llu expects: uint64_t need not be
        // unsigned long long, and %lld would misprint values >= 2**63.
        unsigned long long x = 0;
        if (scanf("%llu", &x) != 1) { return 1; }
        printf("%llu %d\n", x, miller_rabin(x));
    }
    clock_gettime(CLOCK_PROCESS_CPUTIME_ID, &end_time);
    // tv_nsec difference may be negative; summing as double handles the borrow.
    double d_sec = (double)(end_time.tv_sec - start_time.tv_sec)
                 + (double)(end_time.tv_nsec - start_time.tv_nsec) / 1e9;
    fprintf(stderr, "time:%f\n", d_sec);
    return 0;
}
0