結果

問題 No.3030 ミラー・ラビン素数判定法のテスト
ユーザー 👑 MizarMizar
提出日時 2022-08-31 10:21:04
言語 Rust
(1.77.0 + proconio)
結果
AC  
実行時間 16 ms / 9,973 ms
コード長 9,291 bytes
コンパイル時間 12,301 ms
コンパイル使用メモリ 401,452 KB
実行使用メモリ 6,820 KB
最終ジャッジ日時 2024-11-17 00:02:56
合計ジャッジ時間 13,456 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
6,816 KB
testcase_01 AC 1 ms
6,820 KB
testcase_02 AC 1 ms
6,816 KB
testcase_03 AC 1 ms
6,816 KB
testcase_04 AC 11 ms
6,820 KB
testcase_05 AC 12 ms
6,816 KB
testcase_06 AC 10 ms
6,816 KB
testcase_07 AC 10 ms
6,816 KB
testcase_08 AC 10 ms
6,816 KB
testcase_09 AC 16 ms
6,816 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

// -*- coding:utf-8-unix -*-

/// Montgomery modular multiplication over an odd modulus `n`.
///
/// Arguments whose names end in `r` are in Montgomery form, i.e. `x * R (mod n)`
/// for some fixed `R` coprime to `n` (the `u64` implementation in this file uses
/// `R = 2**64`). Implementations are expected to work on already-reduced
/// operands (`< n`).
pub trait MontTrait<T> {
    /// Builds the precomputed context for the odd modulus `n`.
    fn new(n: T) -> Self;

    /// Returns `x + y (mod n)`.
    fn add(&self, x: T, y: T) -> T;

    /// Returns `x - y (mod n)`.
    fn sub(&self, x: T, y: T) -> T;

    /// Returns `xr / 2 (mod n)`, i.e. `xr` multiplied by the inverse of 2.
    fn div2(&self, xr: T) -> T;

    /// Montgomery product: returns `(xr * yr) / R (mod n)`.
    fn mrmul(&self, xr: T, yr: T) -> T;

    /// Montgomery reduction: returns `xr / R (mod n)` (leaves Montgomery form).
    fn mr(&self, xr: T) -> T;

    /// Enters Montgomery form: returns `x * R (mod n)`.
    fn ar(&self, x: T) -> T;

    /// Exponentiation inside Montgomery form: returns `((xr / R) ** e) * R (mod n)`.
    fn pow(&self, xr: T, e: T) -> T;
}
/// Precomputed Montgomery-arithmetic context for an odd modulus `n`.
///
/// Below, `R = 2**W`, where `W` is the bit width of `T` (`R = 2**64` for `Mont<u64>`).
/// Derives `Debug`/`Clone`/`Copy`/`PartialEq`/`Eq` so the context can be inspected,
/// copied cheaply (it is only a handful of integers) and compared.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Mont<T, BitCountType = u32> {
    /// The modulus; must be odd and greater than 2.
    n: T,
    /// Inverse of `n` modulo `R`: `n * ni == 1 (mod R)`.
    ni: T,
    /// `(n + 1) / 2`, the inverse of 2 modulo `n`.
    nh: T,
    /// `R (mod n)`: the value 1 in Montgomery form.
    r: T,
    /// `-R (mod n)`: the value `n - 1` (i.e. -1) in Montgomery form.
    rn: T,
    /// `R**2 (mod n)`: multiplier used to convert a value into Montgomery form.
    r2: T,
    /// Odd part of `n - 1`: `(n - 1) >> (n - 1).trailing_zeros()`.
    d: T,
    /// `(n - 1).trailing_zeros()`, so that `n - 1 == d << k`.
    k: BitCountType,
}
// Montgomery arithmetic for 64-bit odd moduli, with R = 2**64.
// A value named `*r` (e.g. `ar`, `br`) is in Montgomery form: x * R (mod n).
// Every operand is expected to be already reduced (< n); the debug_asserts check this.
impl MontTrait<u64> for Mont<u64> {
    /// Precomputes all Montgomery constants for the odd modulus `n`.
    ///
    /// `n` must be odd (debug-asserted). The struct documentation additionally
    /// requires `n > 2`; with `n == 1`, `rn` would equal `n` and the
    /// `mrmul(1, r2)` self-check below would fail its `< n` precondition.
    #[inline]
    fn new(n: u64) -> Self {
        debug_assert_eq!(n & 1, 1);
        // Newton/Hensel iteration for ni = n^-1 (mod 2**64): starting from ni0 = n
        // (correct to 3 low bits, since odd n satisfies n*n == 1 mod 8), each step
        // doubles the number of correct bits, so 5 steps reach 3 * 2**5 = 96 >= 64 bits.
        // // n is odd number, n = 2*k+1, n >= 1, n < 2**64, k is non-negative integer, k >= 0, k < 2**63
        // ni0 := n; // = 2*k+1 = (1+(2**2)*((k*(k+1))**1))/(2*k+1)
        let mut ni = n;
        // ni1 := ni0 * (2 - (n * ni0)); // = (1-(2**4)*((k*(k+1))**2))/(2*k+1)
        // ni2 := ni1 * (2 - (n * ni1)); // = (1-(2**8)*((k*(k+1))**4))/(2*k+1)
        // ni3 := ni2 * (2 - (n * ni2)); // = (1-(2**16)*((k*(k+1))**8))/(2*k+1)
        // ni4 := ni3 * (2 - (n * ni3)); // = (1-(2**32)*((k*(k+1))**16))/(2*k+1)
        // ni5 := ni4 * (2 - (n * ni4)); // = (1-(2**64)*((k*(k+1))**32))/(2*k+1)
        // // (n * ni5) mod 2**64 = ((2*k+1) * ni5) mod 2**64 = 1 mod 2**64
        for _ in 0..5 {
            ni = ni.wrapping_mul(2u64.wrapping_sub(n.wrapping_mul(ni)));
        }
        debug_assert_eq!(n.wrapping_mul(ni), 1); // n * ni == 1 (mod 2**64)
        // (n + 1) / 2 without overflow: n is odd, so (n >> 1) + 1 == (n + 1) / 2.
        let nh = (n >> 1) + 1;
        // wrapping_neg gives 2**64 - n, which is congruent to 2**64 modulo n.
        let r = n.wrapping_neg() % n; // == 2**64 (mod n)
        let rn = n - r;
        // Same trick in 128 bits: 2**128 - n is congruent to 2**128 modulo n.
        let r2 = ((n as u128).wrapping_neg() % (n as u128)) as u64; // == 2**128 (mod n)
        // Split n - 1 = d * 2**k with d odd (used by the Miller-Rabin test).
        let mut d = n - 1;
        let k = d.trailing_zeros();
        d >>= k;
        debug_assert_eq!(Self { n, ni, nh, r, rn, r2, d, k }.mr(r), 1); // r / r == 1 (mod n)
        debug_assert_eq!(Self { n, ni, nh, r, rn, r2, d, k }.mrmul(1, r2), r); // r2 / r == r (mod n)
        Self { n, ni, nh, r, rn, r2, d, k }
    }
    /// Modular addition of two reduced values.
    #[inline]
    fn add(&self, a: u64, b: u64) -> u64 {
        // == a + b (mod n)
        debug_assert!(a < self.n);
        debug_assert!(b < self.n);
        let (t, fa) = a.overflowing_add(b);
        let (u, fs) = t.overflowing_sub(self.n);
        // If a + b overflowed 64 bits the true sum is >= n, so the wrapped t - n is the answer;
        // otherwise keep t when t < n (the subtraction borrowed) and t - n when t >= n.
        if fa { u } else { if fs { t } else { u } }
    }
    /// Modular subtraction of two reduced values.
    #[inline]
    fn sub(&self, a: u64, b: u64) -> u64 {
        // == a - b (mod n)
        debug_assert!(a < self.n);
        debug_assert!(b < self.n);
        let (t, f) = a.overflowing_sub(b);
        // On borrow, t == a - b + 2**64; adding n (wrapping) yields a - b + n, which is in [0, n).
        if f { t.wrapping_add(self.n) } else { t }
    }
    /// Halving modulo the odd modulus n.
    #[inline]
    fn div2(&self, ar: u64) -> u64 {
        // == ar / 2 (mod n)
        debug_assert!(ar < self.n);
        let t = ar >> 1;
        // For odd ar: (ar + n) / 2 == (ar - 1) / 2 + (n + 1) / 2 == t + nh, computed without overflow.
        if (ar & 1) == 0 { t } else { t + self.nh }
    }
    /// Montgomery product (REDC of the 128-bit product).
    #[inline]
    fn mrmul(&self, ar: u64, br: u64) -> u64 {
        // == (ar * br) / r (mod n)
        debug_assert!(ar < self.n);
        debug_assert!(br < self.n);
        let (n, ni) = (self.n, self.ni);
        let t: u128 = (ar as u128) * (br as u128);
        // m = lo(t) * ni (mod 2**64) makes m * n == lo(t) (mod 2**64), so t - m * n is an exact
        // multiple of 2**64 whose quotient is hi(t) - hi(m * n). Both high halves are < n,
        // so the difference lies in (-n, n); a borrow is fixed up by adding n.
        let (t, f) = ((t >> 64) as u64).overflowing_sub((((((t as u64).wrapping_mul(ni)) as u128) * (n as u128)) >> 64) as u64);
        if f { t.wrapping_add(n) } else { t }
    }
    /// Montgomery reduction of a single value (leaves Montgomery form).
    #[inline]
    fn mr(&self, ar: u64) -> u64 {
        // == ar / r (mod n)
        debug_assert!(ar < self.n);
        let (n, ni) = (self.n, self.ni);
        // Same as mrmul with a 64-bit input: hi(t) is 0, so the result is -hi(m * n),
        // mapped back into [0, n) by adding n when it is non-zero.
        let (t, f) = (((((ar.wrapping_mul(ni)) as u128) * (n as u128)) >> 64) as u64).overflowing_neg();
        if f { t.wrapping_add(n) } else { t }
    }
    /// Conversion into Montgomery form.
    #[inline]
    fn ar(&self, a: u64) -> u64 {
        // == a * r (mod n)
        debug_assert!(a < self.n);
        // (a * R**2) / R == a * R (mod n).
        self.mrmul(a, self.r2)
    }
    /// Right-to-left binary exponentiation, staying in Montgomery form.
    #[inline]
    fn pow(&self, mut ar: u64, mut b: u64) -> u64 {
        // == ((ar / r) ** b) * r (mod n)
        debug_assert!(ar < self.n);
        // Start from the lowest exponent bit; self.r is 1 in Montgomery form, so b == 0 returns 1.
        let mut t = if (b & 1) == 0 { self.r } else { ar };
        loop {
            b >>= 1;
            if b == 0 { return t; }
            // ar now holds the base raised to the next power of two.
            ar = self.mrmul(ar, ar);
            if (b & 1) != 0 { t = self.mrmul(t, ar); }
        }
    }
}

/// 64-bit integer square root using a fixed number of iterations.
///
/// Returns `(floor(sqrt(iv)), remainder)` where `remainder == iv - root * root`.
#[inline]
#[allow(unused)]
fn isqrt64f(iv: u64) -> (u64, u64) {
    // Skip no leading bits: always run the full 32 digit-pair iterations.
    const SKIP_NONE: u32 = 0;
    isqrt64i(iv, SKIP_NONE)
}
/// 64-bit integer square root whose iteration count depends on `iv`.
///
/// Leading zero bits of `iv` are skipped, so small inputs take fewer iterations.
/// Returns `(floor(sqrt(iv)), remainder)` where `remainder == iv - root * root`.
#[inline]
#[allow(unused)]
fn isqrt64d(iv: u64) -> (u64, u64) {
    let skippable_bits = iv.leading_zeros();
    isqrt64i(iv, skippable_bits)
}
/// Internal digit-by-digit (base 4) integer square root of a 64-bit value.
///
/// `lz` is the number of high bits of `iv` known to be zero; `lz / 2` of the 32
/// digit-pair iterations are skipped, and the odd bit of `lz` is ignored.
/// Returns `(floor(sqrt(iv)), remainder)` where `remainder == iv - root * root`.
///
/// Layout of the accumulator `acc` (before scaling down by `2 * (lz / 2)` bits):
/// the partial root lives at bit 64 and above, and the constant trial bit `2**62`
/// sits in the low half. Instead of shifting the trial bit right, the running
/// remainder `rem` is shifted left by 2 each round.
#[inline]
fn isqrt64i(iv: u64, lz: u32) -> (u64, u64) {
    let half_skip = lz >> 1;
    let rounds = 32 - half_skip;
    let shift = half_skip << 1;
    let out_shift = rounds << 1;

    let mut rem: u128 = iv as u128;
    let mut acc: u128 = (1u128 << 62) >> shift;
    // Keeps the doubled root bits, dropping the bit that receives the new digit.
    let root_mask: u128 = (0xffff_ffff_ffff_fffe_u128 << 64) >> shift;
    // The new root bit appended after a successful trial subtraction.
    let digit: u128 = (1u128 << 64) >> shift;
    // Low half of the accumulator: holds the constant trial bit.
    let low_mask: u128 = (u64::MAX as u128) >> shift;

    for _ in 0..rounds {
        let doubled = ((acc << 1) & root_mask) + (acc & low_mask);
        let fits = rem >= acc;
        if fits {
            rem -= acc;
        }
        acc = if fits { doubled + digit } else { doubled };
        rem <<= 2;
    }
    ((acc >> out_shift) as u64, (rem >> out_shift) as u64)
}

/// Jacobi symbol `(a / n)`.
///
/// `n` must be odd and positive (the symbol is undefined otherwise). Returns `-1`,
/// `0` or `1`. Returns `0` exactly when `gcd(a, n) != 1`.
///
/// Works for every `i64` value of `a`, including `i64::MIN`.
#[inline]
fn jacobi(a: i64, mut n: u64) -> i32 {
    // (-a / n) = (-1 / n) * (a / n), and (-1 / n) == -1 exactly when n ≡ 3 (mod 4).
    let mut j: i32 = if a < 0 && (n & 3) == 3 { -1 } else { 1 };
    // `unsigned_abs` is exact for all i64 values; `(-a) as u64` overflowed for i64::MIN.
    let mut a: u64 = a.unsigned_abs();
    while a > 0 {
        // Remove factors of two: (2 / n) == -1 exactly when n ≡ 3 or 5 (mod 8).
        let twos = a.trailing_zeros();
        a >>= twos;
        if (twos & 1) != 0 && matches!(n & 7, 3 | 5) {
            j = -j;
        }
        // Quadratic reciprocity for odd a and n: flip when both are ≡ 3 (mod 4).
        if (a & n & 3) == 3 {
            j = -j;
        }
        std::mem::swap(&mut a, &mut n);
        a %= n;
        // Keep a <= n / 2 via (a / n) = (-1 / n) * ((n - a) / n).
        if a > (n >> 1) {
            a = n - a;
            if (n & 3) == 3 {
                j = -j;
            }
        }
    }
    // Every step preserves gcd(a, n), so n now equals that gcd; the symbol is 0 unless it is 1.
    if n == 1 { j } else { 0 }
}

/// Strong probable-prime (Miller–Rabin) test to base 2.
///
/// With `n - 1 = d * 2**k` and `d` odd (both precomputed in `mont`), returns `true`
/// when `2**d == ±1 (mod n)` or `2**(d * 2**i) == -1 (mod n)` for some `1 <= i < k`.
/// Strong pseudoprimes to base 2 (https://oeis.org/A001262) also pass:
/// 2047, 3277, 4033, 4681, 8321, 15841, 29341, 42799, 49141, 52633, ...
#[inline]
fn primetest_base2(mont: &Mont<u64>) -> bool {
    // 1 and -1 in Montgomery form.
    let one = mont.r;
    let minus_one = mont.rn;
    // 2 in Montgomery form is simply 1 + 1.
    let two = mont.add(one, one);
    let mut x = mont.pow(two, mont.d);
    if x == one || x == minus_one {
        return true;
    }
    // Square up to k - 1 more times looking for -1.
    (1..mont.k).any(|_| {
        x = mont.mrmul(x, x);
        x == minus_one
    })
}

/// Strong Lucas probable-prime test for the odd modulus `mont.n`.
///
/// Parameters are chosen Selfridge-style: D is the first of 5, -7, 9, -11, 13, ...
/// with Jacobi symbol (D / n) == -1, P = 1 and Q = (1 - D) / 4.
/// Writing n + 1 = d * 2**s with d odd, returns `true` when U_d == 0 (mod n) or
/// V_(d * 2**r) == 0 (mod n) for some 0 <= r < s.
///
/// Returns `false` early when some candidate D with |D| < n shares a factor with n,
/// or when n is a perfect square (checked once after 32 candidates, since for a
/// square (D / n) is never -1 and the search would not terminate).
/// Strong Lucas pseudoprimes (https://oeis.org/A217255) also pass:
/// 5459, 5777, 10877, 16109, 18971, 22499, 24569, 25199, 40309, 58519, ...
#[inline]
fn primetest_lucas(mont: &Mont<u64>) -> bool {
    // Lucas primality test
    // strong Lucas pseudoprimes ( https://oeis.org/A217255 ): 5459,5777,10877,16109,18971,22499,24569,25199,40309,58519,...
    let n = mont.n;
    let mut d: i64 = 5;
    for i in 0.. {
        debug_assert!(i < 64);
        match jacobi(d, n) {
            -1 => break,
            // gcd(D, n) > 1 and |D| < n means n has a proper factor.
            0 => if ((d.abs()) as u64) < n { return false; },
            _ => {},
        }
        if i == 32 && isqrt64f(n).1 == 0 { return false; }
        // Next candidate in the sequence 5, -7, 9, -11, 13, ...
        if (i & 1) == 1 { d = 2 - d; } else { d = -(d + 2); }
    }
    // Q = (1 - D) / 4 reduced mod n, in Montgomery form. For D > 0, Q is negative,
    // so it is represented as n - ((D - 1) / 4 mod n).
    let qm = mont.ar(if d < 0 {((1 - d) as u64) / 4 % n} else {n - ((d - 1) as u64) / 4 % n});
    // Left-align n + 1 so its top set bit is bit 63; the ladder then scans bits from the top.
    // NOTE(review): n + 1 overflows for n == u64::MAX; that value fails the base-2 test in
    // primetest_bpsw before reaching here, but confirm for any other caller.
    let mut k = (n + 1) << (n + 1).leading_zeros();
    // U_1 = 1 and V_1 = P = 1, both in Montgomery form; qn tracks Q**index.
    let mut um = mont.r;
    let mut vm = mont.r;
    let mut qn = qm;
    // D mod n in Montgomery form.
    let dm: u64 = mont.ar(if d < 0 { n - (((-d) as u64) % n) } else { (d as u64) % n });
    // The top bit is already accounted for by index 1.
    k <<= 1;
    // The loop stops after the lowest set bit of n + 1, so trailing zeros are not
    // processed and (um, vm) end up as (U_d, V_d) for the odd part d of n + 1.
    while k > 0 {
        // Index doubling: U_2m = U_m * V_m, V_2m = V_m**2 - 2 * Q**m, Q**2m = (Q**m)**2.
        um = mont.mrmul(um, vm);
        vm = mont.sub(mont.mrmul(vm, vm), mont.add(qn, qn));
        qn = mont.mrmul(qn, qn);
        if (k >> 63) != 0 {
            // Index increment with P = 1: U_(m+1) = (U_m + V_m) / 2, V_(m+1) = (D * U_m + V_m) / 2.
            let uu = mont.div2(mont.add(um, vm));
            vm = mont.div2(mont.add(mont.mrmul(dm, um), vm));
            um = uu;
            qn = mont.mrmul(qn, qm);
        }
        k <<= 1;
    }
    if um == 0 || vm == 0 {
        return true;
    }
    // x = lowest set bit of n + 1, i.e. 2**s; after the shift, the loop runs s - 1 times
    // checking V_(d * 2**r) for r = 1 .. s - 1.
    let mut x = (n + 1) & (!n);
    x >>= 1;
    while x > 0 {
        um = mont.mrmul(um, vm);
        vm = mont.sub(mont.mrmul(vm, vm), mont.add(qn, qn));
        if vm == 0 {
            return true;
        }
        qn = mont.mrmul(qn, qn);
        x >>= 1;
    }
    false
}

/// Baillie–PSW primality test for 64-bit integers.
///
/// Combines a strong probable-prime (Miller–Rabin) test to base 2
/// (strong pseudoprimes: https://oeis.org/A001262) with a strong Lucas
/// probable-prime test (strong Lucas pseudoprimes: https://oeis.org/A217255).
/// No composite is known to pass both; published verifications report none
/// below 2**64, so the result is expected to be exact for `u64` inputs.
pub fn primetest_bpsw(n: u64) -> bool {
    match n {
        2 => true,
        // 0, 1 and all other even numbers are not prime.
        _ if n < 2 || n % 2 == 0 => false,
        _ => {
            let mont = Mont::<u64>::new(n);
            primetest_base2(&mont) && primetest_lucas(&mont)
        }
    }
}

/// Reads a count followed by that many integers (one per line) and echoes each
/// as `"<x> 1"` if it is prime or `"<x> 0"` otherwise; elapsed time goes to stderr.
fn main() {
    use std::io::{BufRead, Write};
    let timer = std::time::Instant::now();
    let stdin = std::io::stdin();
    let mut input_lines = std::io::BufReader::new(stdin.lock()).lines();
    let mut next_line = move || input_lines.next().unwrap().unwrap();

    let count = next_line().parse::<usize>().unwrap();
    let mut report = String::with_capacity(count * 22);
    for _ in 0..count {
        let line = next_line();
        let is_prime = primetest_bpsw(line.parse::<u64>().unwrap());
        report.push_str(&line);
        report.push_str(if is_prime { " 1\n" } else { " 0\n" });
    }

    std::io::stdout().write_all(report.as_bytes()).unwrap();
    eprintln!("{}us", timer.elapsed().as_micros());
}
0