結果
| 問題 |
No.981 一般冪乗根
|
| ユーザー |
|
| 提出日時 | 2022-05-20 02:36:24 |
| 言語 | C++14 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
AC
|
| 実行時間 | 10 ms / 6,000 ms |
| コード長 | 4,433 bytes |
| コンパイル時間 | 1,672 ms |
| コンパイル使用メモリ | 178,540 KB |
| 実行使用メモリ | 10,140 KB |
| 最終ジャッジ日時 | 2024-09-19 07:29:53 |
| 合計ジャッジ時間 | 66,895 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 42 TLE * 2 |
ソースコード
#include <bits/stdc++.h>
#define sz(v) ((int)(v).size())
#define all(v) (v).begin(), (v).end()
using namespace std;
typedef long long lint;
typedef pair<lint, lint> pi;
const int MAXN = 30005;
const int mod = 1e9 + 7;
struct mint {
int val;
mint() { val = 0; }
mint(const lint& v) {
val = (-mod <= v && v < mod) ? v : v % mod;
if (val < 0) val += mod;
}
friend ostream& operator<<(ostream& os, const mint& a) { return os << a.val; }
friend bool operator==(const mint& a, const mint& b) { return a.val == b.val; }
friend bool operator!=(const mint& a, const mint& b) { return !(a == b); }
friend bool operator<(const mint& a, const mint& b) { return a.val < b.val; }
mint operator-() const { return mint(-val); }
mint& operator+=(const mint& m) { if ((val += m.val) >= mod) val -= mod; return *this; }
mint& operator-=(const mint& m) { if ((val -= m.val) < 0) val += mod; return *this; }
mint& operator*=(const mint& m) { val = (lint)val*m.val%mod; return *this; }
friend mint ipow(mint a, lint p) {
mint ans = 1; for (; p; p /= 2, a *= a) if (p&1) ans *= a;
return ans;
}
friend mint inv(const mint& a) { assert(a.val); return ipow(a, mod - 2); }
mint& operator/=(const mint& m) { return (*this) *= inv(m); }
friend mint operator+(mint a, const mint& b) { return a += b; }
friend mint operator-(mint a, const mint& b) { return a -= b; }
friend mint operator*(mint a, const mint& b) { return a *= b; }
friend mint operator/(mint a, const mint& b) { return a /= b; }
operator int64_t() const {return val; }
};
lint gcd(lint x, lint y){ return y ? gcd(y, x%y) : x; }
lint mul(lint x, lint y, lint p){ return (__int128) x * y % p; }
// find x such that x^k == a (mod p)
// time: log^2 q + log q min(p^0.5, q^0.25)
// https://judge.yosupo.jp/submission/78315
lint KthRootModPrime(lint a, lint k, lint p) {
auto pow = [](lint a, lint n, lint p) -> lint {
lint r = 1;
for (; n > 0; n >>= 1, a = mul(a, a, p))
if (n % 2 == 1) r = mul(r, a, p);
return r;
};
auto inv = [](lint a, lint p) -> lint {
a %= p;
lint u = 1, v = 0;
lint b = p;
while (b > 0) {
lint q = a / b;
a %= b;
u -= mul(v, q, p);
u = (u % p + p) % p;
swap(u, v);
swap(a, b);
}
return u < 0 ? u + p : u;
};
auto peth_root = [&](lint a, lint p, int e,
lint mod) -> lint {
lint q = mod - 1;
int s = 0;
while (q % p == 0) {
q /= p;
++s;
}
lint pe = pow(p, e, mod);
lint ans = pow(a, ((__int128)mul(pe - 1, inv(q, pe), pe) * q + 1) / pe, mod);
lint c = 2;
while (pow(c, (mod - 1) / p, mod) == 1) ++c;
c = pow(c, q, mod);
unordered_map<lint, int> map;
lint add = 1;
int v = (int)std::sqrt((double)(s - e) * p) + 1;
lint mult = pow(c, mul(v, pow(p, s - 1, mod - 1), mod - 1), mod);
for (int i = 0; i <= v; ++i) {
map[add] = i;
add = mul(add, mult, mod);
}
mult = inv(pow(c, pow(p, s - 1, mod - 1), mod), mod);
for (int i = e; i < s; ++i) {
lint err = mul(a, inv(pow(ans, pe, mod), mod), mod);
lint target = pow(err, pow(p, s - 1 - i, mod - 1), mod);
for (int j = 0; j <= v; ++j) {
if (map.find(target) != map.end()) {
int x = map[target];
ans = mul(ans, pow(c, mul(j + mul(v, x, mod - 1), pow(p, i - e, mod - 1), mod - 1), mod), mod);
break;
}
target = mul(target, mult, mod);
assert(j != v);
}
}
return ans;
};
if (k > 0 && a == 0) return 0;
k %= p - 1;
lint g = gcd(k, p - 1);
if (pow(a, (p - 1) / g, p) != 1) return -1;
a = pow(a, inv(k / g, (p - 1) / g), p);
for (lint div = 2; div * div <= g; ++div) {
int sz = 0;
while (g % div == 0) {
g /= div;
++sz;
}
if (sz > 0) {
lint b = peth_root(a, div, sz, p);
a = b;
}
}
if (g > 1) a = peth_root(a, g, 1, p);
return a;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int tc; cin >> tc;
while(tc--){
lint p, k, a;
cin >> p >> k >> a;
cout << KthRootModPrime(a, k, p) << "\n";
}
/*
int n; cin >> n;
vector<pair<mint, mint>> v;
for(int i = 0; i < n; i++){
string s; cin >> s;
mint ans = 0;
for(auto &i : s){
ans = ans * mint(10) + mint(i - '0');
}
mint x = ipow(ans, mod / 2);
mint y = ipow(-ans, mod / 2);
cout << x*x << " " << y*y << "\n";
if(x * x != ans){
assert(y * y == -ans);
v.emplace_back(0, y);
}
else v.emplace_back(x, 0);
}
*/
}