結果

問題 No.950 行列累乗
ユーザー ir5
提出日時 2025-09-03 00:30:18
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 444 ms / 2,000 ms
コード長 6,986 bytes
コンパイル時間 2,937 ms
コンパイル使用メモリ 192,264 KB
実行使用メモリ 19,456 KB
最終ジャッジ日時 2025-09-03 00:30:31
合計ジャッジ時間 11,709 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 57
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <iostream>
#include <numeric>
#include <vector>
#include <map>
#include <set>
#include <queue>
#include <functional>
#include <iomanip>
#include <ranges>
using namespace std;
using ll = long long;

// Lazy half-open integer range [s, e); yields nothing when e <= s.
template<typename T> auto range(T s, T e) {
  const T stop = max(s, e);  // clamp so the view never has end < begin
  return views::iota(s, stop);
}
// Shorthand for range(0, n).
template<typename T> auto range(T n) { return range<T>(T(0), n); }
// Resize `vec` to exactly n elements and fill every slot from standard input.
template<typename T> void take(vector<T>& vec, int n) {
  vec.resize(n);
  for (auto &&slot : vec) cin >> slot;
}
// Print every argument to cout, each followed by a single space
// (the trailing space after the last argument is kept).
template<class... Args> void sout(const Args &...args) {
  ((cout << args << ' '), ...);
}
// Like sout, then end the line with '\n' (no flush).
template<class... Args> void soutn(const Args &...args) {
  sout(args...);
  cout << '\n';
}
// Small aggregates that read 2, 3 or 4 heterogeneous values with one `>>`.
// Each operator>> extracts into local temporaries first and then assigns the
// whole struct at once.
template<typename T1, typename T2> struct In2 {
  T1 a;
  T2 b;
  friend std::istream& operator>>(std::istream& in, In2& dst) {
    T1 first;
    T2 second;
    in >> first >> second;
    dst = In2{first, second};
    return in;
  }
};
template<typename T1, typename T2, typename T3> struct In3 {
  T1 a;
  T2 b;
  T3 c;
  friend std::istream& operator>>(std::istream& in, In3& dst) {
    T1 first;
    T2 second;
    T3 third;
    in >> first >> second >> third;
    dst = In3{first, second, third};
    return in;
  }
};
template<typename T1, typename T2, typename T3, typename T4> struct In4 {
  T1 a;
  T2 b;
  T3 c;
  T4 d;
  friend std::istream& operator>>(std::istream& in, In4& dst) {
    T1 first;
    T2 second;
    T3 third;
    T4 fourth;
    in >> first >> second >> third >> fourth;
    dst = In4{first, second, third, fourth};
    return in;
  }
};

// Debug tracing hook. When LOCAL is defined, dump(...) comes from a local-only
// header (debug.h, not part of this file). Otherwise every dump(...) call
// expands to a lone `;`, so tracing compiles away entirely, and a call written
// without a trailing semicolon still forms a valid statement.
#ifdef LOCAL
#include <debug.h>
#else
#define dump(...) ;
#endif

ll mod = 0;
struct mint {
  ll x;
  mint(ll x_ = 0) : x((x_ % mod + mod) % mod) {}
  mint operator-() const { return mint(-x); }
  mint &operator+=(const mint &a) { if ((x += a.x) >= mod) x -= mod; return *this; }
  mint &operator-=(const mint &a) { if ((x += mod - a.x) >= mod) x -= mod; return *this; }
  mint &operator*=(const mint &a) { (x *= a.x) %= mod; return *this; }
  mint operator+(const mint &a) const { mint res(*this); return res += a; }
  mint operator-(const mint &a) const { mint res(*this); return res -= a; }
  mint operator*(const mint &a) const { mint res(*this); return res *= a; }
  mint pow(ll t) const { if (!t) return 1; mint a = pow(t >> 1); a *= a; if (t & 1) a *= *this; return a; }
  mint inv() const { return pow(mod - 2); }
  mint &operator/=(const mint &a) { return (*this) *= a.inv(); }
  mint operator/(const mint &a) const { mint res(*this); return res /= a; }
  auto operator<=>(const mint&) const = default;
  friend ostream &operator<<(ostream &os, const mint &m) { os << m.x; return os; }
  friend istream &operator>>(istream &is, mint &m) { is >> m.x; return is; }
};

using Field = mint;
typedef vector<Field> Vec;
typedef vector<Vec> Mat;

// Build an n x m matrix from `vec` laid out row by row, so element (i, j)
// is vec[i * m + j]. vec must hold at least n * m values.
template<typename T>
Mat from_flat(vector<T> vec, int n, int m) {
  Mat res(n, Vec(m));
  int idx = 0;
  for (Vec &row : res)
    for (Field &cell : row) cell = vec[idx++];
  return res;
}

// Matrix product mat1 * mat2, where mat1 is h x in and mat2 is in x w.
// The inner dimensions must agree (asserted). An empty left operand yields an
// empty product. Operands are taken by const reference because this sits on
// the hot path of pow() and the discrete-log searches.
Mat mul(const Mat &mat1, const Mat &mat2) {
  const int h = mat1.size();
  if (h == 0) return {};
  const int in = mat1[0].size();
  // Check the shapes before touching mat2[0]; an empty mat2 is only valid when in == 0.
  assert((int)mat2.size() == in);
  const int w = in == 0 ? 0 : (int)mat2[0].size();

  Mat ret(h, Vec(w, 0));
  for (int y : range(h)) {
    const Vec &row = mat1[y];
    for (int x : range(w)) {
      Field acc = 0;
      for (int i : range(in)) acc += row[i] * mat2[i][x];
      ret[y][x] = acc;
    }
  }
  return ret;
}

// Square-matrix power mat^power by binary exponentiation.
// A non-positive power yields the n x n identity.
Mat pow(Mat mat, ll power) {
  const int n = mat.size();
  Mat result(n, Vec(n, 0));
  for (int d = 0; d < n; ++d) result[d][d] = 1;
  // Each round folds in the current bit, then shifts and squares the base.
  for (; power > 0; power >>= 1, mat = mul(mat, mat)) {
    if (power & 1) result = mul(result, mat);
  }
  return result;
}

// Inverse of a modulo p via the recursive identity
//   a^-1 = (1 - p * (p^-1 mod a)) / a   (mod p),
// shifted by +p to make the result positive. Requires a > 0 and gcd(a, p) == 1;
// a == 0 would evaluate p % 0.
ll inv(ll a, ll p) {
  if (a == 1) return 1;
  const ll back = inv(p % a, a);  // inverse of p modulo a
  return (1 - p * back) / a + p;
}

// Find minimum n >= n0 such that a^n = b (mod p).
// If there is no n, returns -1.
//
// Baby-step giant-step on the shifted problem a^t = b * a^(-n0), t >= 0,
// whose answer is n = n0 + t, so the lower bound holds for every n0.
// Assumes every non-zero residue is invertible mod p (p prime). a and b may be
// passed unreduced or negative; n0 < 0 is treated as 0.
// Time and memory are O(sqrt(p)), plus a log factor for the map.
ll discrete_log(ll a, ll b, ll p, ll n0) {
  a = (a % p + p) % p;
  b = (b % p + p) % p;
  n0 = max<ll>(n0, 0);

  // a == 0 has no inverse: a^0 = 1 and a^n = 0 for every n >= 1.
  if (a == 0) {
    if (n0 == 0 && b == 1 % p) return 0;
    if (b == 0) return max<ll>(n0, 1);
    return -1;
  }

  // x^e mod p by binary exponentiation.
  auto pow_mod = [p](ll x, ll e) {
    ll r = 1 % p;
    for (; e > 0; e >>= 1, x = x * x % p)
      if (e & 1) r = r * x % p;
    return r;
  };

  // Fold the lower bound into the target: a^(n0 + t) = b  <=>  a^t = b * a^(-n0).
  const ll target = b * inv(pow_mod(a, n0), p) % p;

  // m = ceil(sqrt(p)): ord(a) <= p - 1 < m * m, so t < m * m covers every residue.
  ll m = 1;
  while (m * m < p) ++m;

  map<ll, ll> mp; // mp[a^j] = smallest j in [0, m) with that value
  ll apow = 1;
  for (ll j = 0; j < m; ++j) {
    if (apow == target) return n0 + j;
    mp.emplace(apow, j);  // emplace keeps the first (smallest) j on repeats
    apow = apow * a % p;
  }
  dump(mp);

  // a^(k*m + j) = target  <=>  a^j = target * (a^(-m))^k
  const ll a_invm = inv(apow, p);
  ll rhs = target;
  for (ll k = 1; k <= m; ++k) {
    rhs = rhs * a_invm % p;
    if (auto it = mp.find(rhs); it != mp.end()) return n0 + k * m + it->second;
  }
  return -1;
}

// Determinant of the top-left 2x2 part of A: a00 * a11 - a01 * a10.
mint det(Mat A) {
  const mint main_diag = A[0][0] * A[1][1];
  const mint anti_diag = A[0][1] * A[1][0];
  return main_diag - anti_diag;
}

// Inverse of a 2x2 matrix via the adjugate: (1 / det) * [[d, -b], [-c, a]].
// A must be invertible (asserted). The reciprocal of the determinant is
// computed once instead of performing four separate modular divisions.
Mat inverse(Mat A) {
  const mint d = det(A);
  assert(d != 0);
  const mint r = d.inv();
  return {{A[1][1] * r, -A[0][1] * r},
          {-A[1][0] * r, A[0][0] * r}};
}

// Find the minimum k >= 0 with a^k = b for 2x2 matrices, or -1 if there is none.
// Baby-step giant-step with step size s ~ sqrt(cycle_max); cycle_max should
// bound the order of `a`. `a` must be invertible, because inverse() is applied
// to a power of it. Operands are taken by const reference to avoid copying
// matrices.
ll discrete_log_mat(const Mat &a, const Mat &b, ll cycle_max) {
  map<Mat, ll> mp; // mp[a^j] = j

  ll s = 0;
  Mat apow = {{1, 0}, {0, 1}};
  for (; s * s <= cycle_max; ++s) {
    if (apow == b) return s;
    // A repeat a^s == a^j (j < s) means, for invertible a, that a^s == I: every
    // power of a has already been examined, so b is unreachable.
    if (!mp.emplace(apow, s).second) return -1;
    apow = mul(a, apow);
  }

  // a^{ks + i} = b  <=>  a^i = b (a^{-s})^k
  const Mat a_invs = inverse(apow);
  Mat a_invsk = a_invs;
  for (ll k = 1; k * s <= cycle_max; k++) {
    const Mat rhs = mul(b, a_invsk);
    if (auto it = mp.find(rhs); it != mp.end()) return k * s + it->second;
    a_invsk = mul(a_invsk, a_invs);
  }
  return -1;
}

namespace solver {


using RetType = ll;

// Input matrices, filled by read(). run() only reads them.
Mat A, B;

// Reads the modulus into the global `mod` (assumed prime, since division uses
// Fermat inverses), then A and B as four numbers each, row by row.
void read() {
  cin >> mod;
  vector<ll> a_flat, b_flat;
  take(a_flat, 4);
  take(b_flat, 4);
  A = from_flat(a_flat, 2, 2);
  B = from_flat(b_flat, 2, 2);
}

// Returns the smallest n >= 1 with A^n = B (mod `mod`), or -1 if none exists.
// Works on local values only, so A and B are left untouched, which keeps
// repeated calls (see stress()) consistent.
RetType run() {
  // Try small exponents directly; the branches below rely on n = 1, 2 being handled here.
  for (int s : range(1, 3)) if (pow(A, s) == B) {
    dump("trivial");
    return s;
  }

  const mint detA = det(A);
  const mint trA = A[0][0] + A[1][1];

  if (detA == 0) {
    dump("detA = 0");
    // For a singular 2x2 matrix, Cayley-Hamilton gives A^2 = tr(A) * A, hence
    // A^n = tr(A)^(n-1) * A for every n >= 1.
    // tr(A) == 0: A^n = 0 for n >= 2, and B = 0 would already have matched n = 2.
    if (trA == 0) return -1;
    ll res = -1;
    for (int i : range(2))
    for (int j : range(2)) {
      if (A[i][j] == 0) {
        // tr^(n-1) * 0 stays 0, so B must be 0 at this entry as well.
        if (B[i][j] != 0) {
          dump("nonzero where A is zero", i, j);
          return -1;
        }
        continue;
      }
      mint b = B[i][j] * A[i][j].inv();
      ll s = discrete_log(trA.x, b.x, mod, 1);

      if (s == -1) {
        dump("impossible", i, j);
        return -1;
      }
      if (res >= 0 && res != s) {
        dump("conflict", i, j);
        return -1;
      }
      res = s;
    }

    // tr(A) != 0 forces a non-zero diagonal entry, so res has been set.
    return res + 1;
  }

  // det(A)^n = det(B) is necessary: n must be m0 + k * m1 (k >= 0), where m1 is
  // the multiplicative order of det(A) and m0 the first exponent hitting det(B).
  const mint detB = det(B);
  const ll m1 = discrete_log(detA.x, 1, mod, 1);
  const ll m0 = discrete_log(detA.x, detB.x, mod, 1);
  dump(m0, m1);

  assert(m1 >= 0);  // non-zero detA has a finite order modulo a prime
  if (m0 == -1) return -1;

  // A^(m0 + k*m1) = B  <=>  (A^m1)^k = A^(-m0) B.
  // Both sides have determinant 1. The 3 * mod search bound is presumably chosen
  // to exceed the largest element order in SL(2, p), which is at most 2p.
  const Mat step = pow(A, m1);
  const Mat target = mul(pow(inverse(A), m0), B);

  dump(step);
  dump(target);

  const ll k = discrete_log_mat(step, target, 3 * mod);
  dump(k);
  if (k == -1) return -1;
  return m0 + k * m1;
}

// Debug helper: for s = 1..99, sets B = A^s and dumps what run() reports.
// run() leaves A untouched, so every iteration uses the A from read().
void stress() {
  for (ll s : range(1, 100)) {
    B = pow(A, s);
    dump(B);
    [[maybe_unused]] ll res = run();
    dump(s, res);
  }
}

/*
// Older scalar stress test. NOTE(review): stale. It calls a 3-argument
// discrete_log and a `power` helper, neither of which exists in this file as shown.
void stress() {
  ll a = 7, p = 89;
  for (ll x : range(p)) {
    ll res = discrete_log(a, power(a, x, p), p);
    dump(x, power(a, x, p), res);
  }
}
*/

}  // namespace solver

// Invoke f. If it returns a value, print that value followed by endl;
// a void-returning f is simply called.
template <typename F>
void run(F f) {
  using Result = decltype(f());
  if constexpr (!std::is_same_v<Result, void>) {
    cout << f() << endl;
  } else {
    f();
  }
}

// Entry point. An optional first argument sets how many test cases to process
// (default 1). Each case is read and then solved immediately, so every case
// gets an answer rather than only the last one. A non-positive count
// processes nothing.
int main(int argc, char** argv) {
  int testcase = 1;
  if (argc > 1) testcase = atoi(argv[1]);
  for (int t = 0; t < testcase; ++t) {
    solver::read();
    // solver::stress();
    run(solver::run);
  }
}
0