結果
| 問題 |
No.950 行列累乗
|
| コンテスト | |
| ユーザー |
akakimidori
|
| 提出日時 | 2019-12-14 03:35:40 |
| 言語 | Rust (1.83.0 + proconio) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 5,856 bytes |
| コンパイル時間 | 13,500 ms |
| コンパイル使用メモリ | 400,996 KB |
| 実行使用メモリ | 11,116 KB |
| 最終ジャッジ日時 | 2024-06-28 02:18:54 |
| 合計ジャッジ時間 | 17,927 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 52 WA * 5 |
ソースコード
use std::io::Read;
fn mod_pow(r: u64, mut n: u64, p: u64) -> u64 {
assert!(r < p);
let mut t = 1;
let mut s = r;
while n > 0 {
if n & 1 == 1 {
t = t * s % p;
}
s = s * s % p;
n >>= 1;
}
t
}
fn inv(a: u64, p: u64) -> u64 {
assert!(0 < a && a < p);
mod_pow(a, p - 2, p)
}
fn log(x: u64, a: u64, p: u64) -> u64 {
assert!(0 < x && x < p && a < p);
let sq = (p as f64).sqrt() as u64 + 1;
let mut map = std::collections::HashMap::new();
for i in (0..sq).rev() {
map.insert(a * mod_pow(inv(x, p), i, p) % p, i);
}
let mut k = 1;
let x = mod_pow(x, sq, p);
let mut ans = p;
for i in 0..sq {
if let Some(&v) = map.get(&k) {
let v = (i * sq + v) % (p - 1);
ans = std::cmp::min(ans, v);
}
k = k * x % p;
}
ans
}
type M = [[u64; 2]; 2];
fn matmul(a: &M, b: &M, p: u64) -> M {
let mut c = [[0; 2]; 2];
for (c, a) in c.iter_mut().zip(a.iter()) {
for (a, b) in a.iter().zip(b.iter()) {
for (c, b) in c.iter_mut().zip(b.iter()) {
*c = (*c + *a * *b) % p;
}
}
}
c
}
fn matmul_s(a: &M, x: u64, p: u64) -> M {
let mut c = [[0; 2]; 2];
for (c, a) in c.iter_mut().zip(a.iter()) {
for (c, a) in c.iter_mut().zip(a.iter()) {
*c = *a * x % p;
}
}
c
}
fn mat_pow(a: &M, mut n: u64, p: u64) -> M {
let mut t = [[0; 2]; 2];
t[0][0] = 1;
t[1][1] = 1;
let mut s = a.clone();
while n > 0 {
if n & 1 == 1 {
t = matmul(&t, &s, p);
}
s = matmul(&s, &s, p);
n >>= 1;
}
t
}
/*
* https://twitter.com/maspy_stars/status/1205499459993362432
* を参考に
* A^2 = tr(A)A - det(A)I なので
* det(A) == 0 のときはA^n = (tr(A))^(n - 1)A なので適当に離散対数を解く
* det(A) != 0 のときはA^n = B から(A^r)(A^qk) = B, det(A^r) = det(B), det(A^q) = 1
* となるようなq,rを求める。
* A をA^q, B をBA^(-r)と置き換えて
* A^n = B , det(A) = det(B) = 1 が解ければ良い
* F_p云々はよくわからないがA^n = x * A + y * I とかけることと
* detA = 1 の条件から x^2 + tr(A)xy + y^2 = 1 (mod p)
* という関係式が立ってxに対してyが高々2通りになることから2p程度で抑えられる?
* */
fn run() {
let mut s = String::new();
std::io::stdin().read_to_string(&mut s).unwrap();
let mut it = s.trim().split_whitespace();
let p: u64 = it.next().unwrap().parse().unwrap();
let mut a: M = [[0; 2]; 2];
let mut b: M = [[0; 2]; 2];
for i in 0..2 {
for j in 0..2 {
a[i][j] = it.next().unwrap().parse().unwrap();
}
}
for i in 0..2 {
for j in 0..2 {
b[i][j] = it.next().unwrap().parse().unwrap();
}
}
if a == b {
println!("1");
return;
}
let det_a = (a[0][0] * a[1][1] % p + p - a[0][1] * a[1][0] % p) % p;
if det_a == 0 {
let x = (a[0][0] + a[1][1]) % p;
if x == 0 {
if b == [[0; 2]; 2] {
println!("2");
} else {
println!("-1");
}
return;
}
let mut y = p;
for i in 0..2 {
for j in 0..2 {
if a[i][j] != 0 {
let m = log(x, b[i][j] * inv(a[i][j], p) % p, p);
y = std::cmp::min(y, m);
}
}
}
if y < p && b == matmul_s(&a, mod_pow(x, y, p), p) {
println!("{}", y + 1);
} else {
println!("-1");
}
} else {
let det_b = (b[0][0] * b[1][1] % p + p - b[0][1] * b[1][0] % p) % p;
if det_b == 0 {
println!("-1");
return;
}
let v = log(det_a, det_b, p);
if v >= p {
println!("-1");
return;
}
let ia = matmul_s(&[[a[1][1], (p - a[0][1]) % p], [(p - a[1][0]) % p, a[0][0]]], inv(det_a, p), p);
let b = matmul(&b, &mat_pow(&ia, v, p), p);
assert!(b[0][0] * b[1][1] % p == (b[0][1] * b[1][0] + 1) % p);
let mut phi = p - 1;
let mut factor = vec![];
for k in 2.. {
if k * k > phi {
if phi > 1 {
factor.push(phi);
}
break;
}
if phi % k == 0 {
factor.push(k);
while phi % k == 0 {
phi /= k;
}
}
}
let mut phi = p - 1;
for &f in &factor {
while phi % f == 0 && mod_pow(det_a, phi / f, p) == 1 {
phi /= f;
}
}
let phi = phi;
let a = mat_pow(&a, phi, p);
assert!(a[0][0] * a[1][1] % p == (a[0][1] * a[1][0] + 1) % p);
let ia = [[a[1][1], (p - a[0][1]) % p], [(p - a[1][0]) % p, a[0][0]]];
assert!(matmul(&a, &ia, p) == [[1, 0], [0, 1]]);
let mut map = std::collections::HashMap::new();
let sq = ((2 * p) as f64).sqrt() as u64 + 1;
let mut m = matmul(&b, &mat_pow(&ia, sq - 1, p), p);
for i in (0..sq).rev() {
map.insert(m, i);
m = matmul(&m, &a, p);
}
let mut k = [[1, 0], [0, 1]];
let x = mat_pow(&a, sq, p);
let mut ans = sq * sq;
for i in 0..(sq + 1) {
if let Some(&y) = map.get(&k) {
ans = std::cmp::min(ans, sq * i + y);
}
k = matmul(&k, &x, p);
}
if ans < sq * sq {
let ans = phi.checked_mul(ans).expect("overflow") + v;
println!("{}", ans);
} else {
println!("-1");
}
}
}
fn main() {
run();
}
akakimidori