結果
| 項目 | 内容 |
|---|---|
| 問題 | No.950 行列累乗 |
| ユーザー | |
| 提出日時 | 2025-09-03 00:30:18 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 444 ms / 2,000 ms |
| コード長 | 6,986 bytes |
| コンパイル時間 | 2,937 ms |
| コンパイル使用メモリ | 192,264 KB |
| 実行使用メモリ | 19,456 KB |
| 最終ジャッジ日時 | 2025-09-03 00:30:31 |
| 合計ジャッジ時間 | 11,709 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 57 |
ソースコード
#include <algorithm> #include <cassert> #include <cstdio> #include <cstdlib> #include <cstring> #include <cmath> #include <iostream> #include <numeric> #include <vector> #include <map> #include <set> #include <queue> #include <functional> #include <iomanip> #include <ranges> using namespace std; using ll = long long; template<typename T> auto range(T s, T e) { return views::iota(s, max(s, e)); } template<typename T> auto range(T n) { return range<T>(0, n); } template<typename T> void take(vector<T>& vec, int n) { vec.resize(n); for (int i = 0; i < n; ++i) cin >> vec[i]; } template<class... Args> void sout(const Args &...args) { ((cout << args << ' '), ...); } template<class... Args> void soutn(const Args &...args) { ((cout << args << ' '), ...); cout << '\n'; } template<typename T1, typename T2> struct In2 { T1 a; T2 b; friend std::istream& operator>>(std::istream& is, In2& obj) { T1 t1; T2 t2; is >> t1 >> t2; obj = {t1, t2}; return is; } }; template<typename T1, typename T2, typename T3> struct In3 { T1 a; T2 b; T3 c; friend std::istream& operator>>(std::istream& is, In3& obj) { T1 t1; T2 t2; T3 t3; is >> t1 >> t2 >> t3; obj = {t1, t2, t3}; return is; } }; template<typename T1, typename T2, typename T3, typename T4> struct In4 { T1 a; T2 b; T3 c; T4 d; friend std::istream& operator>>(std::istream& is, In4& obj) { T1 t1; T2 t2; T3 t3; T4 t4; is >> t1 >> t2 >> t3 >> t4; obj = {t1, t2, t3, t4}; return is; } }; #ifdef LOCAL #include <debug.h> #else #define dump(...) 
; #endif ll mod = 0; struct mint { ll x; mint(ll x_ = 0) : x((x_ % mod + mod) % mod) {} mint operator-() const { return mint(-x); } mint &operator+=(const mint &a) { if ((x += a.x) >= mod) x -= mod; return *this; } mint &operator-=(const mint &a) { if ((x += mod - a.x) >= mod) x -= mod; return *this; } mint &operator*=(const mint &a) { (x *= a.x) %= mod; return *this; } mint operator+(const mint &a) const { mint res(*this); return res += a; } mint operator-(const mint &a) const { mint res(*this); return res -= a; } mint operator*(const mint &a) const { mint res(*this); return res *= a; } mint pow(ll t) const { if (!t) return 1; mint a = pow(t >> 1); a *= a; if (t & 1) a *= *this; return a; } mint inv() const { return pow(mod - 2); } mint &operator/=(const mint &a) { return (*this) *= a.inv(); } mint operator/(const mint &a) const { mint res(*this); return res /= a; } auto operator<=>(const mint&) const = default; friend ostream &operator<<(ostream &os, const mint &m) { os << m.x; return os; } friend istream &operator>>(istream &is, mint &m) { is >> m.x; return is; } }; using Field = mint; typedef vector<Field> Vec; typedef vector<Vec> Mat; template<typename T> Mat from_flat(vector<T> vec, int n, int m) { Mat res(n, Vec(m)); for (int i : range(n)) for (int j : range(m)) res[i][j] = vec[i * m + j]; return res; } Mat mul(Mat mat1, Mat mat2) { int h = mat1.size(), in = mat1[0].size(), w = mat2[0].size(); assert((int)mat2.size() == in); Mat ret(h, Vec(w, 0)); for (int y : range(h)) for (int x : range(w)) for (int i : range(in)) { ret[y][x] += mat1[y][i] * mat2[i][x]; } return ret; } Mat pow(Mat mat, ll power) { int n = mat.size(); Mat ret(n, Vec(n, 0)); for (int i : range(n)) ret[i][i] = 1; while (power > 0) { if (power & 1) ret = mul(ret, mat); mat = mul(mat, mat); power >>= 1; } return ret; } ll inv(ll a, ll p) { return (a == 1 ? 1 : (1 - p * inv(p % a, a)) / a + p); } // Find minimum n >= n0 such that a^n = b (mod p). // If there is no n, returns -1. 
ll discrete_log(ll a, ll b, ll p, ll n0) { const ll cycle_max = p; map<ll, ll> mp; // mp[a^j] = j ll s = 0; ll apow = 1; for (; s * s <= cycle_max; ++s) { if (s >= n0) { if (apow == b) return s; if (mp.count(apow)) return -1; } mp[apow] = s; apow = a * apow % p; } dump(mp); // a^{ks + i} = b <=> a^i = b (a^{-s})^k ll a_invs = inv(apow, p); ll a_invsk = a_invs; for (ll k = 1; k * s <= cycle_max; k++) { ll rhs = b * a_invsk % p; if (mp.count(rhs)) return k * s + mp[rhs]; a_invsk *= a_invs; a_invsk %= p; } return -1; } mint det(Mat A) { return A[0][0] * A[1][1] - A[0][1] * A[1][0]; } Mat inverse(Mat A) { mint div = det(A); return {{A[1][1] / div, -A[0][1] / div}, {-A[1][0] / div, A[0][0] / div}}; } ll discrete_log_mat(Mat a, Mat b, ll cycle_max) { map<Mat, ll> mp; // mp[a^j] = j ll s = 0; Mat apow = {{1, 0}, {0, 1}}; for (; s * s <= cycle_max; ++s) { if (apow == b) return s; if (mp.count(apow)) return -1; mp[apow] = s; apow = mul(a, apow); } // a^{ks + i} = b <=> a^i = b (a^{-s})^k Mat a_invs = inverse(apow); Mat a_invsk = a_invs; for (ll k = 1; k * s <= cycle_max; k++) { Mat rhs = mul(b, a_invsk); if (mp.count(rhs)) return k * s + mp[rhs]; a_invsk = mul(a_invsk, a_invs); } return -1; } namespace solver { using RetType = ll; Mat A, B; void read() { cin >> mod; vector<ll> a_flat, b_flat; take(a_flat, 4); take(b_flat, 4); A = from_flat(a_flat, 2, 2); B = from_flat(b_flat, 2, 2); } RetType run() { for (int s : range(1, 3)) if (pow(A, s) == B) { dump("trivial"); return s; } mint detA = det(A); mint trA = A[0][0] + A[1][1]; if (detA == 0) { dump("detA = 0"); if (trA == 0) return -1; ll res = -1; for (int i : range(2)) for (int j : range(2)) { if (A[i][j] != 0) { mint b = B[i][j] * A[i][j].inv(); ll s = discrete_log(trA.x, b.x, mod, 1); if (s == -1) { dump("impossible", i, j); return -1; } if (res >= 0 && res != s) { dump("conflict", i, j) return -1; } res = s; } } return res + 1; } mint detB = det(B); ll m1 = discrete_log(detA.x, 1, mod, 1); ll m0 = discrete_log(detA.x, 
detB.x, mod, 1); dump(m0, m1); assert(m1 >= 0); if (m0 == -1) return -1; // m = m0 + k * m1 // A^m0 * (A^m1)^k = B // (A^m1)^k = (A^-m0) B // // we rewrite this as A^k = B B = mul(pow(inverse(A), m0), B); A = pow(A, m1); dump(A); dump(B); ll m = discrete_log_mat(A, B, 3 * mod); dump(m); if (m == -1) return -1; return m0 + m * m1; for (ll m = 0; m <= 1000; ++m) { if (pow(A, m) == B) return m0 + m * m1; } return -666; } void stress() { for (ll s : range(1, 100)) { B = pow(A, s); dump(B); ll res = run(); dump(s, res); } } /* void stress() { ll a = 7, p = 89; for (ll x : range(p)) { ll res = discrete_log(a, power(a, x, p), p); dump(x, power(a, x, p), res); } } */ } // namespace template <typename F> void run(F f) { if constexpr (std::is_same_v<decltype(f()), void>) f(); else cout << f() << endl; } int main(int argc, char** argv) { int testcase = 1; if (argc > 1) testcase = atoi(argv[1]); while (testcase--) { solver::read(); } // solver::stress(); run(solver::run); }