結果

問題 No.981 一般冪乗根
ユーザー Min_25Min_25
提出日時 2020-02-09 09:33:22
言語 C++14
(gcc 13.2.0 + boost 1.83.0)
結果
AC  
実行時間 5 ms / 6,000 ms
コード長 5,958 bytes
コンパイル時間 1,467 ms
コンパイル使用メモリ 107,228 KB
実行使用メモリ 5,376 KB
最終ジャッジ日時 2024-04-18 03:11:44
合計ジャッジ時間 42,925 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ(β)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 3 ms
5,248 KB
testcase_01 AC 3 ms
5,248 KB
testcase_02 AC 3 ms
5,376 KB
testcase_03 AC 3 ms
5,376 KB
testcase_04 AC 3 ms
5,376 KB
testcase_05 AC 3 ms
5,376 KB
testcase_06 AC 3 ms
5,376 KB
testcase_07 AC 3 ms
5,376 KB
testcase_08 AC 3 ms
5,376 KB
testcase_09 AC 3 ms
5,376 KB
testcase_10 AC 3 ms
5,376 KB
testcase_11 AC 3 ms
5,376 KB
testcase_12 AC 3 ms
5,376 KB
testcase_13 AC 3 ms
5,376 KB
testcase_14 AC 3 ms
5,376 KB
testcase_15 AC 3 ms
5,376 KB
testcase_16 AC 3 ms
5,376 KB
testcase_17 AC 3 ms
5,376 KB
testcase_18 AC 3 ms
5,376 KB
testcase_19 AC 3 ms
5,376 KB
testcase_20 AC 3 ms
5,376 KB
testcase_21 AC 3 ms
5,376 KB
testcase_22 AC 3 ms
5,376 KB
testcase_23 AC 3 ms
5,376 KB
testcase_24 AC 3 ms
5,376 KB
testcase_25 AC 4 ms
5,376 KB
testcase_26 AC 3 ms
5,376 KB
testcase_27 AC 2 ms
5,376 KB
testcase_28 AC 5 ms
5,376 KB
evil_60bit1.txt WA -
evil_60bit2.txt WA -
evil_60bit3.txt WA -
evil_hack AC 2 ms
5,376 KB
evil_hard_random RE -
evil_hard_safeprime.txt RE -
evil_hard_tonelli0 RE -
evil_hard_tonelli1 RE -
evil_hard_tonelli2 RE -
evil_hard_tonelli3 RE -
evil_sefeprime1.txt WA -
evil_sefeprime2.txt WA -
evil_sefeprime3.txt WA -
evil_tonelli1.txt WA -
evil_tonelli2.txt WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cstdio>
#include <cassert>
#include <cmath>
#include <cstring>
#include <ctime>

#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#include <set>
#include <functional>
#include <stack>
#include <queue>

#include <tuple>

using namespace std;

// Short numeric aliases used throughout the file.
// Signed integers (__int128 is a GCC/Clang extension).
using i64 = signed long long;
using i128 = __int128;

// Unsigned integers.
using u8 = unsigned char;
using u32 = unsigned int;
using u64 = unsigned long long int;
using u128 = unsigned __int128;

// Extended-precision floating point.
using f80 = long double;

// Modular arithmetic in Montgomery form with a runtime modulus shared by all
// instances (the modulus and its helpers are static members, so only one
// modulus is active at a time). A value v is stored as x = v * 2^64 mod `mod`.
// Requirements visible in the code: set_mod() asserts `mod` is odd, and the
// sign test in reduce() assumes mod < 2^63.
struct Mod64 {
  Mod64() : x(0) {}
  // Deliberately implicit: code such as `Mod64 z = 2` and `x = 1` relies on it.
  Mod64(u64 n) : x(init(n)) {}

  static u64 modulus() { return mod; }
  // Converts a plain value w to Montgomery form: reduce(w * 2^128) = w * 2^64 mod `mod`.
  static u64 init(u64 w) { return reduce(u128(w) * r2); }
  // Sets the modulus for all Mod64 values. `inv` becomes m^{-1} mod 2^64 via
  // Newton iteration (m * m == 1 mod 8 gives 3 correct bits; each step doubles
  // them), and r2 = 2^128 mod m.
  static void set_mod(u64 m) {
    mod = m; assert(mod & 1);
    inv = m; for (int i = 0; i < 5; ++i) inv *= 2 - inv * m;
    r2 = -u128(m) % m;
  }
  // Montgomery reduction: returns x * 2^-64 mod `mod` in [0, mod) for
  // x < mod * 2^64. The low 64 bits of x and of u64(x) * inv * mod agree, so
  // only the high halves are subtracted; the difference lies in (-mod, mod).
  static u64 reduce(u128 x) {
    u64 y = u64(x >> 64) - u64((u128(u64(x) * inv) * mod) >> 64);
    return i64(y) < 0 ? y + mod : y;
  }
  // Plain (non-Montgomery) value in [0, mod).
  u64 get() const { return reduce(x); }
  // Stores n directly as the raw Montgomery representation (no conversion).
  void set(u64 n) { x = n; }
  // Binary exponentiation: *this raised to e.
  Mod64 pow(u64 e) const {
    Mod64 ret = Mod64(1);
    for (Mod64 b = *this; e; e >>= 1, b *= b) if (e & 1) ret *= b;
    return ret;
  }
  // Inverse via Fermat's little theorem; only valid when `mod` is prime and *this != 0.
  Mod64 inverse() const { return pow(mod - 2); }
  Mod64& operator += (Mod64 rhs) { if (i64(x += rhs.x - mod) < 0) x += mod; return *this; }
  Mod64& operator -= (Mod64 rhs) { if (i64(x -= rhs.x) < 0) x += mod; return *this; }
  Mod64& operator *= (Mod64 rhs) { x = reduce(u128(x) * rhs.x); return *this; }
  Mod64 operator + (Mod64 rhs) const { return Mod64(*this) += rhs; }
  Mod64 operator - (Mod64 rhs) const { return Mod64(*this) -= rhs; }
  Mod64 operator * (Mod64 rhs) const { return Mod64(*this) *= rhs; }
  // Comparing Montgomery representations is equivalent to comparing values,
  // since both are kept in [0, mod).
  bool operator == (const Mod64& rhs) const { return x == rhs.x; }
  bool operator != (const Mod64& rhs) const { return x != rhs.x; }

  friend ostream& operator << (ostream& os, const Mod64& m) { return os << m.get(); }

  // Low bits of the internal (Montgomery) representation, not of the value.
  // Memo uses this (`x & mask`) as a cheap bucket hash.
  int operator & (int t) const { return x & t; }

  static u64 mod, inv, r2;  // modulus, mod^{-1} mod 2^64, 2^128 mod mod
  u64 x;                    // Montgomery representation of the value
};
// Definitions of the shared static state.
u64 Mod64::mod, Mod64::inv, Mod64::r2;

// Baby-step / giant-step table for discrete logarithms base `g`, where `g` is
// assumed to have multiplicative order `period`.
//
// The constructor stores the baby steps g^0 .. g^(size-1), counting-sorted into
// buckets by `x & mask`, where `size` is the largest power of two not exceeding
// min(s, period) (at least 1). find(x) then multiplies x by g^size (one giant
// step) until it lands in the table.
//
// T must be default-constructible, implicitly constructible from 1, and
// support `*=`, `==` and `& int` (the latter is only used as a hash).
template <typename T>
struct Memo {
  Memo(const T& g, int s, int period)
      : size(1 << __lg(max(1, min(s, period)))), mask(size - 1), period(period),
        vs(size), os(size + 1) {
    // Counting sort of the baby steps by bucket: first count per bucket...
    T x(1);
    for (int i = 0; i < size; ++i, x *= g) os[x & mask]++;
    for (int i = 1; i < size; ++i) os[i] += os[i - 1];
    // ...then place each (g^i, i) at the end of its bucket, moving os[m] to
    // the bucket's start.
    x = 1;
    for (int i = 0; i < size; ++i, x *= g) vs[--os[x & mask]] = {x, i};
    gpow = x;         // g^size: the giant-step multiplier
    os[size] = size;  // sentinel: end of the last bucket
  }

  // Returns an exponent i in [0, period) with g^i == x.
  // If x is not a power of g, this asserts; with assertions disabled it
  // returns -1 instead of falling off the end of the function.
  int find(T x) const {
    // long long: `t + size` can exceed INT_MAX when period is close to it.
    for (long long t = 0; t < period; t += size, x *= gpow) {
      const int m = x & mask;
      for (int i = os[m]; i < os[m + 1]; ++i) {
        if (x == vs[i].first) {
          // x * g^t == g^i  =>  x == g^(i - t)  (mod period in the exponent).
          const long long ret = vs[i].second - t;
          return int(ret < 0 ? ret + period : ret);
        }
      }
    }
    assert(0);
    return -1;
  }

  T gpow;                      // g^size
  int size, mask, period;      // table size (power of two), size - 1, order of g
  vector< pair<T, int> > vs;   // (g^i, i) pairs grouped by bucket
  vector<int> os;              // bucket m occupies vs[os[m] .. os[m + 1])
};

// Trial-division factorization: returns (prime, exponent) pairs of n in
// increasing prime order. Returns an empty list for n <= 1.
vector< pair<i64, int> > factors(i64 n) {
  vector< pair<i64, int> > result;
  // `d <= n / d` is the overflow-free form of d * d <= n.
  for (i64 d = 2; d <= n / d; ++d) {
    int exponent = 0;
    while (n % d == 0) {
      n /= d;
      ++exponent;
    }
    if (exponent > 0) result.emplace_back(d, exponent);
  }
  // Whatever is left above 1 has no divisor <= its square root: it is prime.
  if (n > 1) result.emplace_back(n, 1);
  return result;
}

// Modular inverse of a modulo `mod` via the extended Euclidean algorithm.
// Returns the inverse in [0, mod); asserts if gcd(a, mod) != 1.
i64 mod_inv(i64 a, i64 mod) {
  // Invariant: prev_r == prev_x * a_original (mod `mod`), same for cur_r/cur_x.
  i64 prev_r = a, cur_r = mod;
  i64 prev_x = 1, cur_x = 0;
  while (cur_r != 0) {
    const i64 quot = prev_r / cur_r;
    const i64 next_r = prev_r - quot * cur_r;
    prev_r = cur_r;
    cur_r = next_r;
    const i64 next_x = prev_x - quot * cur_x;
    prev_x = cur_x;
    cur_x = next_x;
  }
  if (prev_r != 1) assert(0);
  return prev_x < 0 ? prev_x + mod : prev_x;
}

// Returns r with r^(p^e) == a (mod `mod`), where `mod` is prime, p is a prime
// factor of mod - 1 and p^e divides mod - 1 (msqrtn_p calls it only that way).
// Tonelli-Shanks generalized to p^e-th roots (Adleman-Manders-Miller style);
// discrete logs in the order-p subgroup are taken with the Memo BSGS table.
// Preconditions (not checked here): Mod64::set_mod(mod) has been called and
// `a` is a p^e-th power residue. Otherwise behavior is presumably undefined
// (e.g. a negative ppows index below).
Mod64 msqrtp_p(Mod64 a, i64 p, int e, i64 mod) {
  const Mod64 one(1);

  // Write mod - 1 = q * p^s1 with p not dividing q.
  i64 q = mod - 1; int s1 = 0;
  while (q % p == 0) q /= p, ++s1;
 
  // ppows[i] = p^i for 0 <= i <= max(e, s1).
  i64 ppows[65] = {1};
  for (int i = 1, m = max(e, s1); i <= m; ++i) ppows[i] = ppows[i - 1] * p;

  // d = q * (-q)^{-1} mod p^e, so d == -1 (mod p^e) and p^e divides d + 1.
  // Then r^(p^e) = a^(d+1) = a * t, so r is already a root iff t == 1.
  i64 pe = ppows[e], d = mod_inv(pe - q % pe, pe) * q;
  Mod64 r = a.pow((d + 1) / pe), t = a.pow(d);
  if (t == one) return r;

  // s2 = smallest s >= 1 with t^(p^s) == 1; t (a power of a^q) has order p^s2.
  int s2 = 1;
  for (Mod64 t2 = t.pow(p); t2 != one; t2 = t2.pow(p), ++s2);

  // Find z such that c = z^q has order exactly p^s1, i.e. g = c^(p^(s1-1)) != 1.
  // g then has order p and generates the subgroup used for discrete logs.
  // NOTE(review): the `u` declared here is unused; the loop below declares its own.
  Mod64 c, g, u;
  for (Mod64 z = 2; ; z += one) {
    c = z.pow(q), g = c.pow(ppows[s1 - 1]);
    if (g != one) break;
  }

  // Scale c down to order p^(s2 + e).
  c = c.pow(ppows[s1 - s2 - e]);
  // Discrete-log table base g (order p), about sqrt(p * s2) baby steps.
  // NOTE(review): p is narrowed to int for Memo's period; assumes p < 2^31.
  Memo<Mod64> memo(g, int(sqrt(p * s2)), p);

  // Invariant: r^(p^e) == a * t, with u == c^(p^e). Each step finds i with
  // t^(p^(s2-1)) == g^i and multiplies t by u^(p-i) == (c^(p-i))^(p^e) and
  // r by c^(p-i), which keeps the invariant and (per the AMM argument)
  // reduces the order of t. Stops when t == 1, making r a root.
  for (Mod64 u = c.pow(ppows[e]); t != one; u = u.pow(p), c = c.pow(p), --s2) {
    int i = memo.find(t.pow(ppows[s2 - 1]));
    if (i > 0) t *= u.pow(p - i), r *= c.pow(p - i);
  }
  return r;
}

i64 msqrtn_p(i64 a, i64 n, i64 p) {
  assert(n >= 1);
  a %= p, n %= p - 1;
  if (a <= 1) return a;
  i64 g = __gcd(p - 1, n);
  Mod64::set_mod(p);
  Mod64 ma(a), one(1);
  if (ma.pow((p - 1) / g) != one) return -1;
  ma = ma.pow(mod_inv(n / g, (p - 1) / g));
  for (auto pp : factors(g)) {
    ma = msqrtp_p(ma, pp.first, pp.second, p);
  }
  return ma.get();
}

// -----

// Returns a^e mod `mod` in [0, mod) by binary exponentiation.
// `a` may be any 64-bit value, including negative or >= mod; it is reduced
// first, so the result is never negative. Requires mod >= 1 and e >= 0.
// Products are formed in 128-bit arithmetic, so any 64-bit modulus is safe.
i64 pow_mod(i64 a, i64 e, i64 mod) {
  a %= mod;
  if (a < 0) a += mod;  // C++ % keeps the sign of the dividend
  i64 ret = 1 % mod;    // modulo 1 every value (even a^0) is 0
  for (; e; e >>= 1, a = i128(a) * a % mod) {
    if (e & 1) ret = i128(ret) * a % mod;
  }
  return ret;
}

// Self-test for msqrtn_p; not run by default (the call in solve() is
// commented out). Aborts via assert on the first mismatch.
void verify() {
  // 1) Exhaustive check for every prime p <= 500 and every k in [1, p]:
  //    msqrtn_p(b, k, p) returns a genuine k-th root of b, or -1 exactly when
  //    b is not a k-th power modulo p.
  for (int p = 2; p <= 500; ++p) {
    const auto f = factors(p);
    const bool prime = f.size() == 1 && f[0].second == 1;
    if (!prime) continue;
    for (int k = 1; k <= p; ++k) {
      // exists[v] is true iff v == x^k mod p for some x.
      vector<bool> exists(p);
      for (int x = 0; x < p; ++x) exists[pow_mod(x, k, p)] = true;
      for (int b = 0; b < p; ++b) {
        int a = msqrtn_p(b, k, p);
        if (a < 0) {
          assert(!exists[b]);
        } else {
          assert(pow_mod(a, k, p) == b);
        }
      }
    }
    printf("%d: ok\n", p);
  }

  // 2) One large instance with a modulus around 1e18.
  {
    const i64 b = 604438754303967844;
    const int k = 499999273;
    const i64 p = 999997092002114117;
    const i64 a = msqrtn_p(b, k, p);
    assert(a >= 0 && pow_mod(a, k, p) == b);
  }

  // 3) For k = 3, 9, 27, ... < 100000, recover a k-th root of a^k
  //    (for small a) modulo p ~ 1.3e18.
  {
    const i64 p = 1300820172573992383;
    int k = 3;
    while (k < 100000) {
      for (int a = 1; a < 1000; ++a) {
        const i64 target = pow_mod(a, k, p);
        const auto t = msqrtn_p(target, k, p);
        assert(pow_mod(t, k, p) == pow_mod(a, k, p));
      }
      k *= 3;
    }
  }
}

// Reads T, then T lines "p k a", and prints one x with x^k == a (mod p), or -1
// if none exists, per line (computed by msqrtn_p).
void solve() {
  // verify();
  int T;
  if (scanf("%d", &T) != 1) return;
  for (; T > 0; --T) {
    // All three values are read as 64-bit integers: msqrtn_p works with 64-bit
    // moduli (see verify()), and reading them with %d into int silently
    // truncated large inputs, producing wrong answers or tripping the
    // odd-modulus assert in Mod64::set_mod.
    i64 p, k, a;
    if (scanf("%lld %lld %lld", &p, &k, &a) != 3) break;
    const i64 ans = msqrtn_p(a, k, p);
    printf("%lld\n", ans);
  }
}

// Entry point: runs solve() and reports the processor time it took (as
// measured by clock() from <ctime>) on stderr, so stdout carries only answers.
int main() {
  const clock_t beg = clock();
  solve();
  const clock_t end = clock();
  fprintf(stderr, "%.3f sec\n", double(end - beg) / CLOCKS_PER_SEC);
  return 0;
}
0