結果

問題 No.3030 ミラー・ラビン素数判定法のテスト
ユーザー sakikuroesakikuroe
提出日時 2024-03-05 13:21:10
言語 Rust (1.77.0)
結果 AC
実行時間 64 ms / 9,973 ms
コード長 5,651 bytes
コンパイル時間 732 ms
コンパイル使用メモリ 177,216 KB
実行使用メモリ 6,676 KB
最終ジャッジ日時 2024-03-05 13:21:12
合計ジャッジ時間 2,617 ms
ジャッジサーバーID (参考情報) judge15 / judge12
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 1 ms 6,676 KB
testcase_01 AC 0 ms 6,676 KB
testcase_02 AC 1 ms 6,676 KB
testcase_03 AC 1 ms 6,676 KB
testcase_04 AC 49 ms 6,676 KB
testcase_05 AC 49 ms 6,676 KB
testcase_06 AC 40 ms 6,676 KB
testcase_07 AC 40 ms 6,676 KB
testcase_08 AC 39 ms 6,676 KB
testcase_09 AC 64 ms 6,676 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

// use proconio::input;
use std::ops;

#[allow(non_snake_case)]
/// Precomputed constants for Montgomery multiplication modulo an odd `N`,
/// using the radix `R = 2^64`.
///
/// Every odd modulus in `1..=u64::MAX` is supported; the reduction and the
/// addition are written so that no intermediate overflows even when `N >= 2^63`.
pub struct Montgomery {
    pub N: u64,   // modulus; must be odd
    N_prime: u64, // N * N_prime == -1 (mod R)
    R2: u64,      // R^2 mod N, used to convert into Montgomery form
}

#[allow(non_snake_case)]
impl Montgomery {
    /// log2 of the Montgomery radix: R = 2^LOG_R.
    const LOG_R: u64 = 64;
    /// The Montgomery radix R = 2^64 (needs u128 to be represented).
    const R: u128 = 1_u128 << Montgomery::LOG_R;

    /// Builds the Montgomery context for the odd modulus `N`.
    ///
    /// # Panics
    /// Panics if `N` is even (R = 2^64 would then have no inverse mod N).
    pub fn new(N: u64) -> Self {
        assert!(N % 2 == 1);

        /// Extended Euclid: returns `(x, y)` with `a*x + b*y == gcd(a, b)`,
        /// the gcd being normalized to be positive.
        fn extended_gcd(a: i128, b: i128) -> (i128, i128) {
            let (mut s, mut xs, mut ys) = (a, 1, 0);
            let (mut t, mut xt, mut yt) = (b, 0, 1);

            // Invariants: a*xs + b*ys == s and a*xt + b*yt == t.
            while s % t != 0 {
                let q = s / t;
                let (u, xu, yu) = (s - q * t, xs - q * xt, ys - q * yt);
                (s, t) = (t, u);
                (xs, ys, xt, yt) = (xt, yt, xu, yu);
            }

            if t < 0 {
                (xt, yt) = (-xt, -yt);
            }

            (xt, yt)
        }

        // R*x + N*y == 1  =>  N*y == 1 (mod R)  =>  N_prime = -y mod R.
        let N_prime = {
            let (_, b) = extended_gcd(Montgomery::R as i128, N as i128);
            if b <= 0 {
                (-b) as u64
            } else {
                (-b + Montgomery::R as i128) as u64
            }
        };

        // (R mod N)^2 < N^2 < 2^128, so this product cannot overflow.
        let R2 = ((Montgomery::R % N as u128) * (Montgomery::R % N as u128) % N as u128) as u64;

        Montgomery { N, N_prime, R2 }
    }

    /// Montgomery reduction (REDC).
    ///
    /// Returns:
    ///     T * R^{-1} mod N, fully reduced into 0..N
    ///
    /// Constraints:
    ///     - T < N * R
    fn montgomery_reduction(&self, T: u128) -> u64 {
        // m = T * N_prime mod R is chosen so that T + m*N is divisible by R.
        let m = (T as u64).wrapping_mul(self.N_prime);
        let mn = m as u128 * self.N as u128;
        // T + m*N < 2*N*R, which exceeds u128::MAX once N >= 2^63; keep the
        // carry-out as bit 128 so the quotient below is exact for every odd N.
        let (sum, carry) = T.overflowing_add(mn);
        let t = (sum >> Montgomery::LOG_R) | ((carry as u128) << Montgomery::LOG_R);
        // t < 2N (possibly >= 2^64), so one conditional subtraction lands in 0..N.
        let n = self.N as u128;
        (if t >= n { t - n } else { t }) as u64
    }

    /// Converts a plain value into Montgomery form.
    ///
    /// Returns:
    ///     T * R mod N
    fn montgomery(&self, T: u64) -> u64 {
        // T < R and R2 < N, so the product is < N * R as REDC requires.
        self.montgomery_reduction(T as u128 * self.R2 as u128)
    }

    /// Returns:
    ///     (A + B) * R mod N
    ///
    /// Constraints:
    ///     - ar < N
    ///     - br < N
    fn montgomery_add(&self, ar: u64, br: u64) -> u64 {
        // ar + br can exceed u64::MAX when N > 2^63; the true sum is then
        // >= 2^64 > N, so subtracting N with wrap-around cancels the lost carry.
        let (t, overflowed) = ar.overflowing_add(br);
        if overflowed || t >= self.N {
            t.wrapping_sub(self.N)
        } else {
            t
        }
    }

    /// Returns:
    ///     A * B * R mod N
    ///
    /// Constraints:
    ///     - ar < N
    ///     - br < N
    fn montgomery_mul(&self, ar: u64, br: u64) -> u64 {
        // ar * br < N^2 <= N * R, satisfying the REDC precondition.
        self.montgomery_reduction(ar as u128 * br as u128)
    }
}

/// A residue modulo `montgomery.N`, stored in Montgomery form (`ar = a * R mod N`)
/// and borrowing the shared `Montgomery` context.
///
/// The value is just a `u64` plus a shared reference, so it is `Copy`;
/// arithmetic never needs an explicit clone.
#[derive(Clone, Copy)]
pub struct MontgomeryModInt<'a> {
    ar: u64, // a * R mod N
    pub montgomery: &'a Montgomery,
}

impl<'a> MontgomeryModInt<'a> {
    /// Converts the plain value `a` into Montgomery form under `montgomery`.
    pub fn new(a: u64, montgomery: &'a Montgomery) -> MontgomeryModInt<'a> {
        let ar = montgomery.montgomery(a);
        MontgomeryModInt { ar, montgomery }
    }

    /// Returns the plain value `a mod N` (leaves Montgomery form).
    pub fn val(&self) -> u64 {
        self.montgomery.montgomery_reduction(self.ar as u128)
    }

    /// Returns `self^n mod N` by right-to-left binary exponentiation.
    ///
    /// `pow(0)` returns 1.
    pub fn pow(&self, mut n: usize) -> MontgomeryModInt<'a> {
        let mut res = MontgomeryModInt::new(1, self.montgomery);
        let mut x = *self;
        while n > 0 {
            if n & 1 == 1 {
                res = res * x;
            }
            x = x * x;
            n >>= 1;
        }

        res
    }
}

impl<'a> ops::Add for MontgomeryModInt<'a> {
    type Output = MontgomeryModInt<'a>;

    /// Modular addition of two residues in Montgomery form.
    ///
    /// Assumes both operands come from the same `Montgomery` context;
    /// the result keeps `self`'s context.
    fn add(self, rhs: Self) -> Self {
        let sum = self.montgomery.montgomery_add(self.ar, rhs.ar);
        Self { ar: sum, ..self }
    }
}

impl<'a> ops::Mul for MontgomeryModInt<'a> {
    type Output = MontgomeryModInt<'a>;

    /// Modular multiplication of two residues in Montgomery form.
    ///
    /// Assumes both operands come from the same `Montgomery` context;
    /// the result keeps `self`'s context.
    fn mul(self, rhs: Self) -> Self {
        let product = self.montgomery.montgomery_mul(self.ar, rhs.ar);
        Self { ar: product, ..self }
    }
}

/// Deterministic primality test for 64-bit integers.
///
/// Returns `true` if `n` is prime and `false` otherwise.
///
/// Algorithm:
///     Miller-Rabin (strong probable-prime test) using the fixed seven-witness
///     set from the deterministic-variants records page cited below, which
///     covers every n < 2^64.
///
/// References:
///     - [Deterministic variants of the Miller-Rabin primality test. Miller-Rabin SPRP bases records](https://miller-rabin.appspot.com/)
///     - [64bit数の素数判定](https://zenn.dev/mizar/articles/791698ea860581)
pub fn is_prime(n: u64) -> bool {
    match n {
        0 | 1 => return false,
        2 => return true,
        _ if n % 2 == 0 => return false,
        _ => {}
    }

    // Decompose n - 1 = 2^s * d with d odd.
    let n_minus_1 = n - 1;
    let s = n_minus_1.trailing_zeros();
    let d = n_minus_1 >> s;

    let mont = Montgomery::new(n);

    // `true` when n is a strong probable prime to `base`
    // (or `base` is a multiple of n, in which case it says nothing).
    let passes = |base: u64| -> bool {
        let base = base % n;
        if base == 0 {
            return true;
        }

        let mut x = MontgomeryModInt::new(base, &mont).pow(d as usize);
        let first = x.val();
        if first == 1 || first == n_minus_1 {
            return true;
        }

        // Square up to s - 1 more times looking for -1 (mod n).
        for _ in 1..s {
            x = x.clone() * x;
            if x.val() == n_minus_1 {
                return true;
            }
        }

        false
    };

    const WITNESSES: [u64; 7] = [2, 325, 9375, 28178, 450775, 9780504, 1795265022];
    WITNESSES.iter().all(|&w| passes(w))
}

/// Reads `Q` followed by `Q` integers from stdin and, for each `x`, prints
/// `"x 1"` if `x` is prime or `"x 0"` otherwise, one result per line.
///
/// Values are parsed directly as `u64` (the type `is_prime` takes), so the
/// full 64-bit range is accepted on every target. Output goes through one
/// locked, buffered writer instead of a `println!` per query.
///
/// # Panics
/// Panics on malformed or truncated input, or if stdin/stdout fail.
fn main() {
    use std::io::{BufWriter, Read, Write};

    let mut input = String::new();
    std::io::stdin()
        .read_to_string(&mut input)
        .expect("failed to read stdin");
    let mut tokens = input.split_ascii_whitespace();
    let mut next_u64 = || -> u64 {
        tokens
            .next()
            .expect("unexpected end of input")
            .parse()
            .expect("expected a non-negative 64-bit integer")
    };

    let q = next_u64();

    let stdout = std::io::stdout();
    let mut out = BufWriter::new(stdout.lock());
    for _ in 0..q {
        let x = next_u64();
        // u8::from(bool) prints 1 for prime, 0 otherwise.
        writeln!(out, "{} {}", x, u8::from(is_prime(x))).expect("failed to write to stdout");
    }
    // Flush explicitly: BufWriter's Drop would silently swallow a write error.
    out.flush().expect("failed to flush stdout");
}
0