結果

問題 No.950 行列累乗
ユーザー 37zigen37zigen
提出日時 2019-12-13 07:03:25
言語 Java19
(openjdk 21)
結果
TLE  
実行時間 -
コード長 9,594 bytes
コンパイル時間 2,619 ms
コンパイル使用メモリ 80,904 KB
実行使用メモリ 68,984 KB
最終ジャッジ日時 2023-09-08 15:50:25
合計ジャッジ時間 9,116 ms
ジャッジサーバーID
(参考情報)
judge12 / judge13
このコードへのチャレンジ(β)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 TLE -
testcase_01 -- -
testcase_02 -- -
testcase_03 -- -
testcase_04 -- -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
testcase_32 -- -
testcase_33 -- -
testcase_34 -- -
testcase_35 -- -
testcase_36 -- -
testcase_37 -- -
testcase_38 -- -
testcase_39 -- -
testcase_40 -- -
testcase_41 -- -
testcase_42 -- -
testcase_43 -- -
testcase_44 -- -
testcase_45 -- -
testcase_46 -- -
testcase_47 -- -
testcase_48 -- -
testcase_49 -- -
testcase_50 -- -
testcase_51 -- -
testcase_52 -- -
testcase_53 -- -
testcase_54 -- -
testcase_55 -- -
testcase_56 -- -
testcase_57 -- -
testcase_58 -- -
testcase_59 -- -
testcase_60 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

import java.io.*;
import java.util.*;

class Main {
	public static void main(String[] args) {
		new Main().run();
	}

	long MOD;

	class K {// quadratic field
		long a, b, base;

		public K(long a, long b) {
			this.a = a;
			this.b = b;
		}

		K add(K o) {
			return new K(a + o.a, b + o.b);
		}

		K sub(K o) {
			return new K((a - o.a + MOD) % MOD, (b - o.b + MOD) % MOD);
		}

		K mul(K o) {
			return new K((a * o.a + b * o.b * base % MOD) % MOD, (a * o.b + o.a * b) % MOD);
		}

		// (a+b√w)/(c+d√w)=(a+b√w)(c-d√w)/(cc-dd)
		K div(K o) {
			long den = (o.b * o.b - o.a * o.a) % MOD;
			den = (den % MOD + MOD) % MOD;
			K ret = this.div(o);
			ret.a = ret.a * inv(den) % MOD;
			ret.b = ret.b * inv(den) % MOD;
			return ret;
		}
	}

	long[][] pow(long[][] a, long n) {
		long[][] ret = new long[2][2];
		ret[0][0] = ret[1][1] = 1;
		for (; n > 0; n >>= 1, a = mul(a, a)) {
			if (n % 2 == 1)
				ret = mul(ret, a);
		}
		return ret;
	}

	long[][] mul(long coe, long[][] a) {
		long[][] ret = new long[a.length][a[0].length];
		for (int i = 0; i < a.length; ++i)
			for (int j = 0; j < a[i].length; ++j)
				ret[i][j] = (MOD + coe * a[i][j] % MOD) % MOD;
		return ret;
	}

	long[][] mul(long[][] a, long[][] b) {
		long[][] ret = new long[a.length][b[0].length];
		for (int i = 0; i < a.length; ++i) {
			for (int j = 0; j < b[i].length; ++j) {
				for (int k = 0; k < a[i].length; ++k) {
					ret[i][j] += a[i][k] * b[k][j] % MOD;
					ret[i][j] = (ret[i][j] % MOD + MOD) % MOD;
				}
			}
		}
		return ret;
	}

	long det(long[][] a) {
		return (a[0][0] * a[1][1] % MOD - a[0][1] * a[1][0] % MOD + MOD) % MOD;
	}

	long[][] rndmat(long p) {
		long[][] ret = new long[2][2];
		Random rnd = new Random();
		for (int i = 0; i < 2; ++i)
			for (int j = 0; j < 2; ++j)
				ret[i][j] = rnd.nextInt((int) p);
		return ret;
	}

	void run() {

//		long p = 103;
//		MOD = p;
//		for (int i = 0; i < 100000; ++i) {
//			long[][] a = rndmat(p);
//			long[][] b = rndmat(p);
//			// tr(a, b);
//			long ans0 = exact(a, b, p);
//			long ans1 = solve2(a, b, p);
//			if (ans1 != -99 && ans0 != -1) {
//				tr(ans0, ans1);
//			} // if(ans0!=ans1){ // tr(ans0,ans1,a,b,p); //}
//		}

		/**/
		Scanner sc = new Scanner(System.in);
		long p = sc.nextLong();
		long[][] a = new long[2][2];
		long[][] b = new long[2][2];
		for (int i = 0; i < 2; ++i)
			for (int j = 0; j < 2; ++j)
				a[i][j] = sc.nextLong();
		for (int i = 0; i < 2; ++i)
			for (int j = 0; j < 2; ++j)
				b[i][j] = sc.nextLong();
		//System.out.println(exact(a, b, p));
		System.out.println(solve2(a, b, p));
		/**/
	}

	long solve2(long[][] a, long[][] b, long p) {
		if (p == 2)
			return exact(a, b, p);
		Scanner sc = new Scanner(System.in);
		MOD = p;
		long w = (a[0][0] * a[0][0] % MOD - 2 * a[0][0] * a[1][1] % MOD + a[1][1] * a[1][1] % MOD
				+ 4 * a[0][1] * a[1][0] % MOD) % MOD;
		if (!isQuadraticResidue(w)) {
			throw new AssertionError();
		}
		w = sqrt((w + MOD) % MOD, p);
		long det0 = det(a);
		long det1 = det(b);
		if (det0 == 0 || det1 == 0)
			return exact(a, b, p);
		long eigen1 = inv(2) * (-w + a[0][0] + a[1][1]) % MOD;
		long eigen2 = inv(2) * (+w + a[0][0] + a[1][1]) % MOD;
		long[][] S = new long[2][2];
		long[][] J = new long[2][2];
		long[][] eigenvec1 = new long[2][1];
		long[][] eigenvec2 = new long[2][1];
		if (a[1][0] == 0) {
			// {{a,b}
			// {0,d}}
			eigen1 = a[0][0];
			eigen2 = a[1][1];
			eigenvec1 = new long[][] { { 1 }, { 0 } };
			eigenvec2 = new long[][] { { (MOD - a[0][1]) % MOD }, { ((a[0][0] - a[1][1]) % MOD + MOD) % MOD } };
			S = new long[][] { { eigenvec1[0][0], eigenvec2[0][0] }, { eigenvec1[1][0], eigenvec2[1][0] } };
			J = new long[][] { { eigen1, 1 }, { 0, eigen2 } };
			if (!equiv(mul(a, eigenvec1), mul(eigen1, eigenvec1))
					|| !equiv(mul(a, eigenvec2), mul(eigen2, eigenvec2))) {
				tr(a, eigenvec1, eigen1, MOD, mul(a, eigenvec1), mul(eigen1, eigenvec1));
				tr(a, eigenvec2, eigen2, MOD, mul(a, eigenvec2), mul(eigen2, eigenvec2));
				throw new AssertionError();
			}
		} else if (eigen1 != eigen2) {
			eigenvec1 = new long[][] { { (MOD - (-a[0][0] + a[1][1] + w) % MOD) % MOD }, { 2 * a[1][0] % MOD } };
			eigenvec2 = new long[][] { { (MOD - (-a[0][0] + a[1][1] - w) % MOD) % MOD }, { 2 * a[1][0] % MOD } };
			S = new long[][] { { eigenvec1[0][0], eigenvec2[0][0] }, { eigenvec1[1][0], eigenvec2[1][0] } };
			J = new long[][] { { eigen1, 0 }, { 0, eigen2 } };
			if (!equiv(mul(a, eigenvec1), mul(eigen1, eigenvec1))
					|| !equiv(mul(a, eigenvec2), mul(eigen2, eigenvec2))) {
				tr(a, eigenvec1, eigen1, MOD, mul(a, eigenvec1), mul(eigen1, eigenvec1));
				tr(a, eigenvec2, eigen2, MOD, mul(a, eigenvec2), mul(eigen2, eigenvec2));
				throw new AssertionError();
			}
		} else {
			eigenvec2 = new long[2][1];
			if (eigen1 != 0) {
				eigenvec2 = new long[][] { { (MOD - (-a[0][0] + a[1][1] - w) % MOD) % MOD }, { 2 * a[1][0] % MOD } };
			} else {
				eigenvec2 = new long[][] { { (MOD - a[0][1]) % MOD }, { a[0][0] } };
			}
			if (!equiv(mul(a, eigenvec2), mul(eigen1, eigenvec2))) {
				tr(a, eigenvec2, MOD, mul(a, eigenvec2));
				throw new AssertionError();
			}
			long[][] tmp = new long[][] { { a[0][0] - eigen1, a[0][1] }, { a[1][0], a[1][1] - eigen1 } };
			eigenvec1 = mul(tmp, eigenvec2);
			if (eigenvec1[0][0] == 0 && eigenvec1[1][0] == 0)
				throw new AssertionError();
			S = new long[][] { { eigenvec1[0][0], eigenvec2[0][0] }, { eigenvec1[1][0], eigenvec2[1][0] } };
			J = new long[][] { { eigen1, 1 }, { 0, eigen2 } };
		}
		for (int i = 0; i < 2; ++i) {
			for (int j = 0; j < 2; ++j) {
				J[i][j] = (J[i][j] % MOD + MOD) % MOD;
				b[i][j] = (b[i][j] % MOD + MOD) % MOD;
			}
		}
		b = mul(b, S);
		b = mul(invmat(S), b);
		if (eigen1 != eigen2 && eigen1 != 0 && eigen2 != 0) {
			if (b[0][1] != 0 || b[1][0] != 0)
				return -1;
		}
		long sol1 = discretelog(J[0][0], b[0][0]);
		long sol2 = discretelog(J[1][1], b[1][1]);
		if (sol1 == -1 || sol2 == -1)
			return -1;
		long ord1 = ord(J[0][0], MOD);
		long ord2 = ord(J[1][1], MOD);
		long ans1 = sol1;
		long ans2 = sol2;
		if (Math.abs(ans1 - ans2) % gcd(ord1, ord2) != 0)
			return -1;
		while (ans1 != ans2) {
			if (ans1 < ans2)
				ans1 += ord1;
			else
				ans2 += ord2;
		}
		// 2+12x=12+3y
		// 3(4x+y)=10

		while (!equiv(pow(J, ans1), b)) {
			ans1 += lcm(ord1, ord2);
			ans2 += lcm(ord1, ord2);
		}
		return ans1;
	}

	boolean equiv(long[][] a, long[][] b) {
		boolean ret = true;
		if (a[0].length != b[0].length || a[1].length != b[1].length)
			throw new AssertionError();
		for (int i = 0; i < a.length; ++i)
			for (int j = 0; j < a[0].length; ++j)
				ret &= a[i][j] == b[i][j];
		return ret;
	}

	long exact(long[][] a, long[][] b, long p) {
		MOD = p;
		for (int i = 1; i < p * p; ++i) {
			long[][] pw_a = pow(a, i);
			boolean equiv = true;
			for (int j = 0; j < 2; ++j)
				for (int k = 0; k < 2; ++k)
					equiv &= pw_a[j][k] == b[j][k];
			if (equiv)
				return i;
		}
		return -1;
	}

	long inv(long a) {
		return pow(a, MOD - 2);
	}

	long ord(long a, long p) {
		if (a == 1)
			return 1;
		long ret = p - 1;
		for (long div = 2; div * div <= p - 1; ++div) {
			if ((p - 1) % div != 0)
				continue;
			if (pow(a, div) == 1)
				ret = Math.min(ret, div);
			else if (pow(a, (p - 1) / div) == 1)
				ret = Math.min(ret, (p - 1) / div);
		}
		return ret;
	}

	long pow(long a, long n) {
		long ret = 1;
		for (; n > 0; n >>= 1, a = a * a % MOD) {
			if (n % 2 == 1)
				ret = ret * a % MOD;
		}
		return ret;
	}

	// return x s.t. a^x = b && x>0
	long discretelog(long a, long b) {
		if (a == 1) {
			if (b == 1)
				return 1;
			else
				return -1;
		} else if (a == 0) {
			if (b == 0)
				return 1;
			else
				return -1;
		}
		// a^(um+v) = b
		// a^v = b a^(-m)^u
		int m = (int) (Math.sqrt(MOD) + 1);
		long pw = 1;
		HashMap<Long, Integer> map = new HashMap<>();
		for (int v = 0; v <= m; ++v) {
			map.put(pw, v);
			pw = pw * a % MOD;
		}
		long ima = pow(inv(a), m);
		long ipw = 1;
		for (int i = 0; i <= m; ++i) {
			if (map.containsKey(b * ipw % MOD)) {
				long ret = i * m + map.get(b * ipw % MOD);
				if (ret != 0)
					return ret;
			}
			ipw = ipw * ima % MOD;
		}
		return -1;
	}

	long[][] invmat(long[][] a) {
		if (det(a) == 0)
			throw new AssertionError();
		long[][] ret = new long[2][2];
		ret[0][0] = a[1][1];
		ret[1][1] = a[0][0];
		ret[0][1] = (MOD - a[0][1]) % MOD;
		ret[1][0] = (MOD - a[1][0]) % MOD;
		for (int i = 0; i < 2; ++i)
			for (int j = 0; j < 2; ++j)
				ret[i][j] = ret[i][j] * inv(det(a)) % MOD;
		return ret;
	}

	long gcd(long a, long b) {
		if (a > b)
			return gcd(b, a);
		if (a == 0)
			return b;
		return gcd(a, b % a);
	}

	long lcm(long a, long b) {
		return a / gcd(a, b) * b;
	}

	boolean isQuadraticResidue(long a) {
		return pow(a, (MOD - 1) / 2) == 1;
	}

	long sqrt(long a, long p) {
		if (a == 0)
			return 0;
		int b = 0;
		while (pow((b * b % p - a + p) % p, (p - 1) / 2) != p - 1)
			++b;
		long[] d = { 1, 0 };
		long[] m = { b, 1 };
		long n = (p + 1) / 2;
		for (; n > 0; n >>= 1, m = poly_mul(m, m, b, a, p)) {
			if (n % 2 == 1)
				d = poly_mul(d, m, b, a, p);
		}
		return d[0];
	}

	long[] poly_mul(long[] u, long[] v, long b, long a, long p) {
		long[] ret = new long[3];
		for (int i = 0; i < 2; ++i) {
			for (int j = 0; j < 2; ++j) {
				ret[i + j] += u[i] * v[j];
				ret[i + j] %= p;
			}
		}
		ret[0] += ret[2] * (b * b - a);
		ret[0] %= p;
		for (int i = 0; i < ret.length; ++i) {
			while (ret[i] < 0)
				ret[i] += p;
		}
		return Arrays.copyOf(ret, 2);
	}

	void tr(Object... objects) {
		System.out.println(Arrays.deepToString(objects));
	}

}
0