結果

問題 No.950 行列累乗
ユーザー 37zigen37zigen
提出日時 2019-12-13 21:45:31
言語 Java21
(openjdk 21)
結果
RE  
実行時間 -
コード長 10,812 bytes
コンパイル時間 3,194 ms
コンパイル使用メモリ 90,744 KB
実行使用メモリ 55,904 KB
最終ジャッジ日時 2024-06-27 15:38:04
合計ジャッジ時間 14,416 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 168 ms
46,880 KB
testcase_01 AC 193 ms
47,520 KB
testcase_02 RE -
testcase_03 AC 124 ms
41,224 KB
testcase_04 AC 124 ms
41,548 KB
testcase_05 AC 170 ms
44,852 KB
testcase_06 RE -
testcase_07 AC 128 ms
41,524 KB
testcase_08 RE -
testcase_09 RE -
testcase_10 RE -
testcase_11 AC 128 ms
41,568 KB
testcase_12 AC 128 ms
41,024 KB
testcase_13 RE -
testcase_14 RE -
testcase_15 RE -
testcase_16 RE -
testcase_17 AC 115 ms
40,068 KB
testcase_18 AC 128 ms
41,668 KB
testcase_19 AC 129 ms
40,980 KB
testcase_20 AC 130 ms
41,460 KB
testcase_21 RE -
testcase_22 RE -
testcase_23 AC 204 ms
47,628 KB
testcase_24 RE -
testcase_25 AC 201 ms
46,004 KB
testcase_26 AC 205 ms
47,816 KB
testcase_27 AC 130 ms
41,496 KB
testcase_28 RE -
testcase_29 RE -
testcase_30 RE -
testcase_31 RE -
testcase_32 RE -
testcase_33 RE -
testcase_34 RE -
testcase_35 RE -
testcase_36 AC 115 ms
40,204 KB
testcase_37 AC 169 ms
47,440 KB
testcase_38 AC 190 ms
47,924 KB
testcase_39 AC 131 ms
41,228 KB
testcase_40 AC 224 ms
51,140 KB
testcase_41 AC 193 ms
50,688 KB
testcase_42 RE -
testcase_43 RE -
testcase_44 RE -
testcase_45 RE -
testcase_46 RE -
testcase_47 RE -
testcase_48 RE -
testcase_49 RE -
testcase_50 RE -
testcase_51 RE -
testcase_52 AC 177 ms
47,184 KB
testcase_53 AC 189 ms
47,768 KB
testcase_54 RE -
testcase_55 AC 129 ms
41,396 KB
testcase_56 AC 178 ms
47,516 KB
testcase_57 AC 128 ms
41,076 KB
testcase_58 AC 128 ms
41,344 KB
testcase_59 AC 126 ms
40,840 KB
testcase_60 AC 129 ms
41,164 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import java.io.*;
import java.util.*;

class Main {
	public static void main(String[] args) {
		new Main().run();
	}

	long MOD;

	class K {// quadratic field
		long a, b, base;

		public K(long a, long b) {
			this.a = a;
			this.b = b;
		}

		K add(K o) {
			return new K(a + o.a, b + o.b);
		}

		K sub(K o) {
			return new K((a - o.a + MOD) % MOD, (b - o.b + MOD) % MOD);
		}

		K mul(K o) {
			return new K((a * o.a + b * o.b * base % MOD) % MOD, (a * o.b + o.a * b) % MOD);
		}

		// (a+b√w)/(c+d√w)=(a+b√w)(c-d√w)/(cc-dd)
		K div(K o) {
			long den = (o.b * o.b - o.a * o.a) % MOD;
			den = (den % MOD + MOD) % MOD;
			K ret = this.div(o);
			ret.a = ret.a * inv(den, MOD) % MOD;
			ret.b = ret.b * inv(den, MOD) % MOD;
			return ret;
		}
	}

	long[][] pow(long[][] a, long n) {
		long[][] ret = new long[2][2];
		ret[0][0] = ret[1][1] = 1;
		for (; n > 0; n >>= 1, a = mul(a, a)) {
			if (n % 2 == 1)
				ret = mul(ret, a);
		}
		return ret;
	}

	long[][] mul(long coe, long[][] a) {
		long[][] ret = new long[a.length][a[0].length];
		for (int i = 0; i < a.length; ++i)
			for (int j = 0; j < a[i].length; ++j)
				ret[i][j] = (MOD + coe * a[i][j] % MOD) % MOD;
		return ret;
	}

	long[][] mul(long[][] a, long[][] b) {
		long[][] ret = new long[a.length][b[0].length];
		for (int i = 0; i < a.length; ++i) {
			for (int j = 0; j < b[i].length; ++j) {
				for (int k = 0; k < a[i].length; ++k) {
					ret[i][j] += a[i][k] * b[k][j] % MOD;
					ret[i][j] = (ret[i][j] % MOD + MOD) % MOD;
				}
			}
		}
		return ret;
	}

	long det(long[][] a) {
		return (a[0][0] * a[1][1] % MOD - a[0][1] * a[1][0] % MOD + MOD) % MOD;
	}

	long[][] rndmat(long p) {
		long[][] ret = new long[2][2];
		Random rnd = new Random();
		for (int i = 0; i < 2; ++i)
			for (int j = 0; j < 2; ++j)
				ret[i][j] = rnd.nextInt((int) 7);
		return ret;
	}

	void run() {
//		long p = (long) 7;
//		MOD = p;
//		for (int i = 0;; ++i) {
//			long[][] a = rndmat(p);
//			long[][] b = rndmat(p);
//			long ans0 = exact(a, b, p);
//			long ans1 = solve(a, b, p);
//			if (ans1 != -99 && ans0 != ans1) {
//				tr(a, b);
//				tr(ans0, ans1);
//			}
//		}
//
		Scanner sc = new Scanner(System.in);
		long p = sc.nextLong();
		long[][] a = new long[2][2];
		long[][] b = new long[2][2];
		for (int i = 0; i < 2; ++i)
			for (int j = 0; j < 2; ++j)
				a[i][j] = sc.nextLong();
		for (int i = 0; i < 2; ++i)
			for (int j = 0; j < 2; ++j)
				b[i][j] = sc.nextLong();
//		System.out.println(exact(a, b, p));
		System.out.println(solve(a, b, p));
	}

	long solve(long[][] a, long[][] b, long p) {
		if (equiv(a, b))
			return 1;
		if (p == 2)
			return exact(a, b, p);
		Scanner sc = new Scanner(System.in);
		MOD = p;
		long w = (a[0][0] * a[0][0] % MOD - 2 * a[0][0] * a[1][1] % MOD + a[1][1] * a[1][1] % MOD
				+ 4 * a[0][1] * a[1][0] % MOD) % MOD;
		if (!isQuadraticResidue(w)) {
//			return -99;
			throw new AssertionError();
		}
		w = sqrt((w + MOD) % MOD, p);
		long det0 = det(a);
		long det1 = det(b);
		if (det0 == 0 && det1 > 0)
			return -1;
//		if (det0 == 0 || det1 == 0)
//			return exact(a, b, p);
		long eigen1 = inv(2, MOD) * (-w + a[0][0] + a[1][1]) % MOD;
		long eigen2 = inv(2, MOD) * (+w + a[0][0] + a[1][1]) % MOD;
		long[][] S = new long[2][2];
		long[][] J = new long[2][2];
		long[][] eigenvec1 = new long[2][1];
		long[][] eigenvec2 = new long[2][1];
		if (a[1][0] == 0) {
			// {{a,b}
			// {0,d}}
			eigen1 = a[0][0];
			eigen2 = a[1][1];
			eigenvec1 = new long[][] { { 1 }, { 0 } };
			eigenvec2 = new long[][] { { (MOD - a[0][1]) % MOD }, { ((a[0][0] - a[1][1]) % MOD + MOD) % MOD } };
			S = new long[][] { { eigenvec1[0][0], eigenvec2[0][0] }, { eigenvec1[1][0], eigenvec2[1][0] } };
			J = new long[][] { { eigen1, 0 }, { 0, eigen2 } };
			if (!equiv(mul(a, eigenvec1), mul(eigen1, eigenvec1))
					|| !equiv(mul(a, eigenvec2), mul(eigen2, eigenvec2))) {
				tr(a, eigenvec1, eigen1, MOD, mul(a, eigenvec1), mul(eigen1, eigenvec1));
				tr(a, eigenvec2, eigen2, MOD, mul(a, eigenvec2), mul(eigen2, eigenvec2));
				throw new AssertionError();
			}
		} else if (eigen1 != eigen2) {
			eigenvec1 = new long[][] { { (MOD - (-a[0][0] + a[1][1] + w) % MOD) % MOD }, { 2 * a[1][0] % MOD } };
			eigenvec2 = new long[][] { { (MOD - (-a[0][0] + a[1][1] - w) % MOD) % MOD }, { 2 * a[1][0] % MOD } };
			S = new long[][] { { eigenvec1[0][0], eigenvec2[0][0] }, { eigenvec1[1][0], eigenvec2[1][0] } };
			J = new long[][] { { eigen1, 0 }, { 0, eigen2 } };
			if (!equiv(mul(a, eigenvec1), mul(eigen1, eigenvec1))
					|| !equiv(mul(a, eigenvec2), mul(eigen2, eigenvec2))) {
				tr(a, eigenvec1, eigen1, MOD, mul(a, eigenvec1), mul(eigen1, eigenvec1));
				tr(a, eigenvec2, eigen2, MOD, mul(a, eigenvec2), mul(eigen2, eigenvec2));
				throw new AssertionError();
			}
		} else {
			eigenvec2 = new long[2][1];
			if (eigen1 != 0) {
				eigenvec2 = new long[][] { { (MOD - (-a[0][0] + a[1][1] - w) % MOD) % MOD }, { 2 * a[1][0] % MOD } };
			} else {
				eigenvec2 = new long[][] { { (MOD - a[0][1]) % MOD }, { a[0][0] } };
			}
			if (!equiv(mul(a, eigenvec2), mul(eigen1, eigenvec2))) {
				tr(a, eigenvec2, MOD, mul(a, eigenvec2));
				throw new AssertionError();
			}
			long[][] tmp = new long[][] { { a[0][0] - eigen1, a[0][1] }, { a[1][0], a[1][1] - eigen1 } };
			eigenvec1 = mul(tmp, eigenvec2);
			if (eigenvec1[0][0] == 0 && eigenvec1[1][0] == 0)
				throw new AssertionError();
			S = new long[][] { { eigenvec1[0][0], eigenvec2[0][0] }, { eigenvec1[1][0], eigenvec2[1][0] } };
			J = new long[][] { { eigen1, 1 }, { 0, eigen2 } };
		}
		for (int i = 0; i < 2; ++i) {
			for (int j = 0; j < 2; ++j) {
				J[i][j] = (J[i][j] % MOD + MOD) % MOD;
				b[i][j] = (b[i][j] % MOD + MOD) % MOD;
			}
		}
		b = mul(b, S);
		b = mul(invmat(S), b);
		if (b[1][0] != 0 && J[1][0] == 0) {
			return -1;
		}
		if (eigen1 != eigen2 && eigen1 != 0 && eigen2 != 0) {
			if (b[0][1] != 0 || b[1][0] != 0)
				return -1;
		}
		if ((b[0][0] > 0 && J[0][0] == 0) || (b[1][1] > 0 && J[1][1] == 0))
			return -1;
		if ((J[0][1] * (J[0][0] + J[1][0]) % MOD != 0 && b[0][1] == 0) || (J[0][1] == 0 && b[0][1] != 0))
			return -1;
		// A=[a b
		// 0 c]
		// A^(k+1)=[a^(k+1) b(a+c)^k
		// 0 c^(k+1)]
		long sol1 = discretelog(J[0][0], b[0][0]);
		long sol2 = discretelog(J[1][1], b[1][1]);
		if (sol1 == -1 || sol2 == -1)
			return -1;
		long ord1 = ord(J[0][0], MOD);
		long ord2 = ord(J[1][1], MOD);
		long ans1 = sol1;
		long ans2 = sol2;
		if (Math.abs(ans1 - ans2) % gcd(ord1, ord2) != 0)
			return -1;
//		long ans = garner(new long[] { sol1, sol2 }, new long[] { ord1, ord2 });
//		ans1 = ans;
//		ans2 = ans;
		while (ans1 != ans2) {
			if (ans1 < ans2) {
				ans1 += (ans2 - ans1 + ord1 - 1) / ord1 * ord1;
			} else {
				ans2 += (ans1 - ans2 + ord2 - 1) / ord2 * ord2;
			}
		}
		while (!equiv(pow(J, ans1), b)) {
			ans1 += lcm(ord1, ord2);
			ans2 += lcm(ord1, ord2);
		}
		return ans1;
	}

	boolean equiv(long[][] a, long[][] b) {
		boolean ret = true;
		if (a[0].length != b[0].length || a[1].length != b[1].length)
			throw new AssertionError();
		for (int i = 0; i < a.length; ++i)
			for (int j = 0; j < a[0].length; ++j)
				ret &= a[i][j] == b[i][j];
		return ret;
	}

	long exact(long[][] a, long[][] b, long p) {
		MOD = p;
		for (int i = 1; i < p * p; ++i) {
			long[][] pw_a = pow(a, i);
			boolean equiv = true;
			for (int j = 0; j < 2; ++j)
				for (int k = 0; k < 2; ++k)
					equiv &= pw_a[j][k] == b[j][k];
			if (equiv)
				return i;
		}
		return -1;
	}

	long inv(long a, long mod) {
		return pow(a, mod - 2);
	}

	long ord(long a, long p) {
		if (a == 0 || a == 1)
			return 1;
		long ret = p - 1;
		for (long div = 2; div * div <= p - 1; ++div) {
			if ((p - 1) % div != 0)
				continue;
			if (pow(a, div) == 1)
				ret = Math.min(ret, div);
			else if (pow(a, (p - 1) / div) == 1)
				ret = Math.min(ret, (p - 1) / div);
		}
		return ret;
	}

	long pow(long a, long n) {
		long ret = 1;
		for (; n > 0; n >>= 1, a = a * a % MOD) {
			if (n % 2 == 1)
				ret = ret * a % MOD;
		}
		return ret;
	}

	// return x s.t. a^x = b && x>0
	long discretelog(long a, long b) {
		if (a == 1) {
			if (b == 1)
				return 1;
			else
				return -1;
		} else if (a == 0) {
			if (b == 0)
				return 1;
			else
				return -1;
		}
		// a^(um+v) = b
		// a^v = b a^(-m)^u
		int m = (int) (Math.sqrt(MOD) + 1);
		long pw = 1;
		HashMap<Long, Integer> map = new HashMap<>();
		for (int v = 0; v <= m; ++v) {
			map.put(pw, v);
			pw = pw * a % MOD;
		}
		long ima = pow(inv(a, MOD), m);
		long ipw = 1;
		for (int i = 0; i <= m; ++i) {
			if (map.containsKey(b * ipw % MOD)) {
				long ret = i * m + map.get(b * ipw % MOD);
				if (ret != 0)
					return ret;
			}
			ipw = ipw * ima % MOD;
		}
		return -1;
	}

	long[][] invmat(long[][] a) {
		if (det(a) == 0)
			throw new AssertionError();
		long[][] ret = new long[2][2];
		ret[0][0] = a[1][1];
		ret[1][1] = a[0][0];
		ret[0][1] = (MOD - a[0][1]) % MOD;
		ret[1][0] = (MOD - a[1][0]) % MOD;
		for (int i = 0; i < 2; ++i)
			for (int j = 0; j < 2; ++j)
				ret[i][j] = ret[i][j] * inv(det(a), MOD) % MOD;
		return ret;
	}

	long gcd(long a, long b) {
		if (a > b)
			return gcd(b, a);
		if (a == 0)
			return b;
		return gcd(a, b % a);
	}

	long lcm(long a, long b) {
		return a / gcd(a, b) * b;
	}

	boolean isQuadraticResidue(long a) {
		return pow(a, (MOD - 1) / 2) == 1;
	}

	long sqrt(long a, long p) {
		if (a == 0)
			return 0;
		int b = 0;
		while (pow((b * b % p - a + p) % p, (p - 1) / 2) != p - 1)
			++b;
		long[] d = { 1, 0 };
		long[] m = { b, 1 };
		long n = (p + 1) / 2;
		for (; n > 0; n >>= 1, m = poly_mul(m, m, b, a, p)) {
			if (n % 2 == 1)
				d = poly_mul(d, m, b, a, p);
		}
		return d[0];
	}

	long[] poly_mul(long[] u, long[] v, long b, long a, long p) {
		long[] ret = new long[3];
		for (int i = 0; i < 2; ++i) {
			for (int j = 0; j < 2; ++j) {
				ret[i + j] += u[i] * v[j];
				ret[i + j] %= p;
			}
		}
		ret[0] += ret[2] * (b * b - a);
		ret[0] %= p;
		for (int i = 0; i < ret.length; ++i) {
			while (ret[i] < 0)
				ret[i] += p;
		}
		return Arrays.copyOf(ret, 2);
	}

	long garner(long[] x, long[] mod) {
		assert x.length == mod.length;
		int n = x.length;
		long[] gamma = new long[n];
		for (int i = 0; i < n; i++) {
			long prod = 1;
			for (int j = 0; j < i; j++) {
				prod = prod * mod[j] % mod[i];
			}
			gamma[i] = inv(prod, mod[i]);
		}
		long[] v = new long[n];
		v[0] = x[0];
		for (int i = 1; i < n; i++) {
			long tmp = v[i - 1];
			for (int j = i - 2; j >= 0; j--) {
				tmp = (tmp * mod[j] + v[j]) % mod[i];
			}
			v[i] = (x[i] - tmp) * gamma[i] % mod[i];
			while (v[i] < 0)
				v[i] += mod[i];
		}
		long ret = 0;
		for (int i = v.length - 1; i >= 0; i--) {
			ret = (ret * mod[i] + v[i]);
		}
		return ret;
	}

	void tr(Object... objects) {
		System.out.println(Arrays.deepToString(objects));
	}

}
0