結果
| 問題 | No.981 一般冪乗根 |
| ユーザー | |
| 提出日時 | 2020-12-29 18:10:32 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 8 ms / 6,000 ms |
| コード長 | 21,578 bytes |
| コンパイル時間 | 3,662 ms |
| コンパイル使用メモリ | 233,600 KB |
| 最終ジャッジ日時 | 2025-01-17 08:18:58 |
| ジャッジサーバーID (参考情報) | judge1 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 42 TLE * 2 |
ソースコード
#include <bits/stdc++.h>
// Short names for the built-in integer / floating types used throughout the file.
using uint = unsigned;
using ll = long long;
using ull = unsigned long long;
using ld = long double;
// 128-bit integers (GCC/Clang extension).
using LL = __int128;
using ULL = unsigned __int128;
// Priority queues whose top() is the largest (max_heap) or smallest (min_heap) element.
template<class T> using max_heap = std::priority_queue<T>;
template<class T> using min_heap = std::priority_queue<T, std::vector<T>, std::greater<T>>;
/// Modular inverse of `a` modulo `mod` via a recursive Euclid-style reduction.
/// Requires a > 0, mod > 0 and gcd(a, mod) == 1 (a == 0 divides by zero).
/// The intermediate (a - inverse(...)) * mod + 1 must fit in T, so pass a wide T
/// (e.g. a 128-bit type) when a and mod are large.
template<typename T> constexpr T inverse(const T a, const T mod) { return a == 1 ? T{1} : ((a - inverse(mod % a, a)) * mod + 1) / a; }
/// Returns {x, y} with a*x - b*y == 1 for coprime a, b. If one argument is zero,
/// the other must be +-1.
template<typename T> constexpr std::pair<T, T> extgcd(const T a, const T b)
{
    // a == 0:  -b*y == 1  =>  {0, -1/b}.    b == 0:  a*x == 1  =>  {1/a, 0}.
    // These branches used to return a bare T, which does not convert to std::pair,
    // so any instantiation of extgcd failed to compile.
    if (a == 0) { return {T{0}, T(-1) / b}; }
    if (b == 0) { return {T{1} / a, T{0}}; }
    // a*x == 1 (mod b), so a*x - 1 is an exact multiple of b.
    const T x = inverse(a, b), y = (a * x - 1) / b;
    return {x, y};
}
/// Miller-Rabin test of an odd n >= 3 against the witnesses in `as`.
/// T holds n; V must be wide enough for the product of two values below n.
/// Stops at the first witness a >= n, so `as` is expected to be ascending.
template<typename T, typename V>
inline bool miller_rabin(const T& n, const std::vector<T>& as)
{
    // a^k mod n by recursive square-and-multiply.
    auto pow = [&](auto&& self, const V& a, const T k) -> V {
        if (k == 0) { return 1; }
        if (k % 2 == 0) {
            return self(self, (a * a) % V(n), k / 2);
        } else {
            return (self(self, a, k - 1) * a) % V(n);
        }
    };
    // Write n - 1 = d * 2^r with d odd.
    T d = n - 1;
    for (; (d & 1) == 0; d >>= 1) {}
    for (const T& a : as) {
        if (n <= a) { break; }
        T s = d;
        V x = pow(pow, a, s);
        // Square until x hits 1 or n - 1, or the exponent reaches n - 1.
        while (x != 1 and x != n - 1 and s != n - 1) {
            (x *= x) %= V(n);
            s *= 2;
        }
        // a is a witness of compositeness.
        if (x != n - 1 and s % 2 == 0) { return false; }
    }
    return true;
}
/// Deterministic primality test for every 64-bit n.
inline bool is_prime(const unsigned long long n)
{
    // 0 and 1 are not prime. Without this guard n == 1 reaches miller_rabin with
    // n - 1 == 0, and its "strip factors of two" loop never terminates.
    if (n < 2) { return false; }
    if (n % 2 == 0) { return n == 2; }
    if (n < (1ULL << 32)) {
        // {2, 7, 61} is deterministic for n < 4,759,123,141.
        return miller_rabin<unsigned int, unsigned long long>((unsigned int)n, std::vector<unsigned int>{2, 7, 61});
    }
    // Sinclair's seven bases are deterministic for all n < 2^64. The list used to stop
    // at 9780504, and six bases are not known to be sufficient.
    return miller_rabin<unsigned long long, __uint128_t>(n, std::vector<unsigned long long>{2, 325, 9375, 28178, 450775, 9780504, 1795265022});
}
/// Returns a divisor of n > 1 found with Pollard's rho (Floyd cycle detection).
/// Returns 2 for even n, n itself for prime n, and n when every polynomial tried fails.
/// V must hold x*x + c for x, c < n.
template<typename T, typename V = T>
T pollard_rho(const T n)
{
    if (n % 2 == 0) { return 2; }
    if (is_prime(n)) { return n; }
    for (T c = 1; c < n; c++) {
        // c == n - 2 gives x^2 - 2 (mod n), which is skipped.
        if (c == n - 2) { continue; }
        const auto step = [&](const T v) -> T { return T((V(v) * V(v) + V(c)) % V(n)); };
        T tortoise = 2;
        T hare = 2;
        T divisor = 1;
        do {
            tortoise = step(tortoise);
            hare = step(step(hare));
            const T gap = tortoise > hare ? tortoise - hare : hare - tortoise;
            divisor = std::gcd(gap, n);
        } while (divisor == 1);
        // divisor == n means the cycle closed without splitting n; try the next c.
        if (divisor != n) { return divisor; }
    }
    return n;
}
/// Prime factorisation of n_ as a map prime -> exponent.
/// Returns an empty map for n_ == 1 and for n_ == 0. Zero has no factorisation, and it
/// previously recursed forever: pollard_rho(0) returns 2, and 0 / 2 == 0.
std::map<ull, int> prime_factors(const ull n_)
{
    std::map<ull, int> ans;
    if (n_ == 0) { return ans; }
    auto factor = [&](auto&& self, const ull n) -> void {
        if (n == 1) { return; }
        // Values below 2^32 can use 64-bit intermediates; larger ones need 128-bit products.
        const ull p = (n < (1ULL << 32)) ? (ull)pollard_rho<uint, ull>((uint)n) : pollard_rho<ull, __uint128_t>(n);
        if (p == n) {
            // pollard_rho returns n when it finds no proper factor (in particular for prime n).
            ans[p]++;
            return;
        }
        // p may itself be composite, so both parts are factored further.
        self(self, p), self(self, n / p);
    };
    factor(factor, n_);
    return ans;
}
/// Computes some x with x^k == A in the field of `mint`.
/// NOTE(review): relies on mint::mod being prime. The multiplicative group is treated
/// as cyclic of order P - 1, and mint::inv() uses Fermat's little theorem.
///
/// Returns 0 when A == 0, and also when A fails the solvability test. For k == 0 it
/// returns A unchecked. Callers must therefore verify the result (main() compares
/// ans.pow(K) with A).
///
/// Method:
///  1. Reduce to a g-th root with g = gcd(P - 1, k).
///  2. For each prime power p^e dividing g, extract a p^e-th root with a
///     Tonelli-Shanks-style correction loop.
///  3. The loop's discrete logs in the order-p subgroup are found by baby-step giant-step.
template<typename mint>
mint mod_nthroot(mint A, ull k)
{
    const ull P = mint::mod;
    if (A == 0) { return 0; }
    // x^0 == 1 for every x; A is returned as-is and left for the caller to check.
    if (k == 0) { return A; }
    const ull g = std::gcd(P - 1, k);
    // A is a g-th power iff A^((P-1)/g) == 1; otherwise no k-th root exists.
    if (A.pow((P - 1) / g)() != 1) { return 0; }
    // Replace A by A^u with u = (k/g)^{-1} mod (P-1)/g. Any g-th root X of the new A
    // satisfies X^k == original A, so only a g-th root is needed from here on.
    A = A.pow(inverse<ULL>(k / g, (P - 1) / g));
    if (g == 1) { return A; }
    const auto fs = prime_factors(g);
    // Take a p^e-th root for each prime power p^e of g in turn, overwriting A with it.
    for (const auto& [p, e] : fs) {
        // pe = p^e
        ull pe = 1;
        for (int i = 0; i < e; i++) { pe *= p; }
        // Split P - 1 = q * p^Q with q not divisible by p.
        ull q = P - 1, Q = 0;
        while (q % p == 0) { q /= p, Q++; }
        // y = -q^{-1} mod pe, so y*q + 1 is divisible by pe; z = (y*q + 1) / pe.
        const ull y = pe - inverse<ULL>(q, pe), z = ((ULL)y * q + 1) / (ULL)pe;
        // First candidate: X^pe == A^(y*q + 1) == A * A^(y*q).
        mint X = A.pow(z);
        if (Q == 1) {
            // Q == 1 forces e == 1. A is a p-th power, so A^q lies in a subgroup of
            // order p^(Q-e) == 1 and the error term A^(y*q) is already 1: X is exact.
            A = X;
            continue;
        }
        // Find Z that is not a p-th power (Z^((P-1)/p) != 1); Eraser = Z^q then has order p^Q.
        mint Eraser = 1;
        const ull h = (P - 1) / p;
        for (mint Z = 2;; Z += 1) {
            if (Z.pow(h) != 1) {
                Eraser = Z.pow(q);
                break;
            }
        }
        // Error = A^(y*q) == X^pe / A; the loop below multiplies X until Error == 1.
        mint Error = A.pow((ULL)y * q);
        // pEraser = Eraser^(p^(Q-1)) has order p; ipEraser is its inverse.
        mint pEraser = Eraser;
        for (ull i = 0; i < Q - 1; i++) { pEraser = pEraser.pow(p); }
        const mint ipEraser = pEraser.inv();
        // Giant-step table: memo[pEraser^i] = i for i = 0, B, 2B, ... below p.
        // M (about sqrt(p * (Q - e)) / 2, at least 1) sets the table size; B is the stride.
        std::map<ull, ull> memo;
        {
            const ull M = std::max(1ULL, (ull)(std::sqrt(p) * std::sqrt(Q - e) / 2)), B = std::max(1ULL, ((ull)p - 1) / M);
            const mint ppEraser = pEraser.pow(B);
            mint prod = 1;
            for (ull i = 0; i < p; i += B, prod *= ppEraser) {
                // Stops on a power already stored with a nonzero index
                // (operator[] default-inserts 0 for a new key before the assignment).
                if (memo[prod()]) { break; }
                memo[prod()] = i;
            }
        }
        while (Error() != 1) {
            // Find the first i with Error^(p^(i+1)) == 1. Then pError = Error^(p^i)
            // has order p, and l = Q - (i + 1).
            ull l = 0;
            mint pError = Error;
            for (ull i = 0; i < Q; i++) {
                const auto npError = pError.pow(p);
                if (npError == 1) {
                    l = Q - (i + 1);
                    break;
                }
                pError = npError;
            }
            // Baby steps: find j with pError^{-1} * pEraser^{-j} in the table.
            // Then pEraser^c == pError^{-1} with c = i + j.
            ull c = -1;
            {
                mint small = pError.inv();
                for (ull j = 0;; j++, small *= ipEraser) {
                    if (memo.count(small())) {
                        const ull i = memo.at(small());
                        c = i + j;
                        break;
                    }
                }
            }
            // W = Eraser^(c * p^(l-e)). X *= W multiplies Error by W^pe = Eraser^(c * p^l),
            // whose p^i-th power is pEraser^c == pError^{-1}, so Error's order drops.
            auto pEraser2 = Eraser.pow(c);
            for (ull i = 0; i < l - e; i++) { pEraser2 = pEraser2.pow(p); }
            X *= pEraser2, Error *= pEraser2.pow(pe);
        }
        A = X;
    }
    return A;
}
/// Run-time parameters for `modint`: the modulus plus `root` and `max2p`, which are
/// stored alongside it and exposed as modint::root / modint::max2p.
struct modinfo
{
    constexpr modinfo() = default;
    constexpr modinfo(const uint mod_, const uint root_, const uint max2p_) : mod{mod_}, root{root_}, max2p{max2p_} {}
    void set_mod(const uint mod_) { mod = mod_; }
    void set_root(const uint root_) { root = root_; }
    void set_max2p(const uint max2p_) { max2p = max2p_; }
    uint mod{};
    uint root{};
    uint max2p{};
};
/// Integer modulo `info.mod`. `info` is a modinfo object bound by reference, so
/// info.set_mod() changes the modulus for every modint<info>. Stored values are not
/// re-reduced when the modulus changes.
///
/// Values are kept in [0, mod). Arithmetic is overflow-free for every mod in [1, 2^32):
///  - sums are reduced with a comparison instead of a 32-bit addition that may wrap;
///  - products are formed in 64-bit unsigned arithmetic.
/// Previously +, - and the ll constructor wrapped for mod >= 2^31, and * overflowed a
/// signed 64-bit product for mod > ~3.04e9.
///
/// inv() and operator/ use Fermat's little theorem (x^(mod-2)), so they assume mod is prime.
template<const modinfo& info>
class modint
{
public:
    static constexpr const uint& mod = info.mod;
    static constexpr const uint& root = info.root;
    static constexpr const uint& max2p = info.max2p;
    constexpr modint() : m_val{0} {}
    constexpr modint(const ll v) : m_val{normll(v)} {}
    constexpr modint(const modint& m) = default;
    /// Stores v without reduction (v should already be below mod).
    constexpr void set_raw(const uint v) { m_val = v; }
    constexpr modint& operator=(const modint& m) { return m_val = m(), (*this); }
    constexpr modint& operator=(const ll v) { return m_val = normll(v), (*this); }
    constexpr modint operator+() const { return *this; }
    constexpr modint operator-() const { return modint{0} - (*this); }
    constexpr modint& operator+=(const modint& m) { return m_val = add(m_val, m()), *this; }
    constexpr modint& operator-=(const modint& m) { return m_val = sub(m_val, m()), *this; }
    constexpr modint& operator*=(const modint& m) { return m_val = uint(ull(m_val) * m() % mod), *this; }
    constexpr modint& operator/=(const modint& m) { return *this *= m.inv(); }
    constexpr modint& operator+=(const ll val) { return *this += modint{val}; }
    constexpr modint& operator-=(const ll val) { return *this -= modint{val}; }
    constexpr modint& operator*=(const ll val) { return *this *= modint{val}; }
    constexpr modint& operator/=(const ll val) { return *this /= modint{val}; }
    constexpr modint operator+(const modint& m) const { return modint{*this} += m; }
    constexpr modint operator-(const modint& m) const { return modint{*this} -= m; }
    constexpr modint operator*(const modint& m) const { return modint{*this} *= m; }
    constexpr modint operator/(const modint& m) const { return modint{*this} /= m; }
    constexpr modint operator+(const ll v) const { return *this + modint{v}; }
    constexpr modint operator-(const ll v) const { return *this - modint{v}; }
    constexpr modint operator*(const ll v) const { return *this * modint{v}; }
    constexpr modint operator/(const ll v) const { return *this / modint{v}; }
    constexpr bool operator==(const modint& m) const { return m_val == m(); }
    constexpr bool operator!=(const modint& m) const { return not(*this == m); }
    constexpr friend modint operator+(const ll v, const modint& m) { return modint{v} + m; }
    constexpr friend modint operator-(const ll v, const modint& m) { return modint{v} - m; }
    constexpr friend modint operator*(const ll v, const modint& m) { return modint{v} * m; }
    constexpr friend modint operator/(const ll v, const modint& m) { return modint{v} / m; }
    friend std::istream& operator>>(std::istream& is, modint& m)
    {
        ll v;
        return is >> v, m = v, is;
    }
    friend std::ostream& operator<<(std::ostream& os, const modint& m) { return os << m(); }
    /// The stored representative in [0, mod).
    constexpr uint operator()() const { return m_val; }
    /// this^n by binary exponentiation.
    constexpr modint pow(ull n) const
    {
        modint ans = 1;
        for (modint x = *this; n > 0; n >>= 1, x *= x) {
            if (n & 1ULL) { ans *= x; }
        }
        return ans;
    }
    constexpr modint inv() const { return pow(mod - 2); }
    /// Inverse via the memoised small-inverse table, which grows up to the stored value,
    /// so this is only suitable for small values.
    modint sinv() const { return sinv(m_val); }
    /// n!, memoised in a growing static table.
    static modint fact(const uint n)
    {
        static std::vector<modint> fs{1, 1};
        for (uint i = (uint)fs.size(); i <= n; i++) { fs.push_back(fs.back() * i); }
        return fs[n];
    }
    /// 1 / n!, memoised in a growing static table.
    static modint ifact(const uint n)
    {
        static std::vector<modint> ifs{1, 1};
        for (uint i = (uint)ifs.size(); i <= n; i++) { ifs.push_back(ifs.back() * sinv(i)); }
        return ifs[n];
    }
    static modint perm(const int n, const int k) { return k > n or k < 0 ? modint{0} : fact(n) * ifact(n - k); }
    static modint comb(const int n, const int k) { return k > n or k < 0 ? modint{0} : fact(n) * ifact(n - k) * ifact(k); }
private:
    // (a + b) mod `mod` for a, b in [0, mod), never forming a sum that could wrap.
    static constexpr uint add(const uint a, const uint b) { return a >= mod - b ? a - (mod - b) : a + b; }
    // (a - b) mod `mod` for a, b in [0, mod).
    static constexpr uint sub(const uint a, const uint b) { return a >= b ? a - b : a + (mod - b); }
    // Reduces any signed 64-bit value into [0, mod). 0ULL - ull(x) is |x|, even for LLONG_MIN.
    static constexpr uint normll(const ll x)
    {
        if (x >= 0) { return uint(ull(x) % mod); }
        const ull r = (0ULL - ull(x)) % mod;
        return r == 0 ? 0U : uint(mod - r);
    }
    // Inverses of 1..n from inv(i) = -(mod / i) * inv(mod % i), memoised (needs prime mod, n < mod).
    static modint sinv(const uint n)
    {
        static std::vector<modint> is{1, 1};
        for (uint i = (uint)is.size(); i <= n; i++) { is.push_back(-is[mod % i] * (mod / i)); }
        return is[n];
    }
    uint m_val;
};
// Fixed-modulus presets, given as {mod, root, max2p}.
// The root values (5 and 3) appear to be primitive roots of these primes. The max2p
// values match the exponent of 2 in mod - 1:
//   1000000006 = 2 * 500000003  and  998244352 = 2^23 * 7 * 17.
// NOTE(review): confirm the intended root/max2p semantics against their users.
constexpr modinfo modinfo_1000000007 = {1000000007, 5, 1};
constexpr modinfo modinfo_998244353 = {998244353, 3, 23};
using modint_1000000007 = modint<modinfo_1000000007>;
using modint_998244353 = modint<modinfo_998244353>;
/// Run-time 64-bit modulus for `modint64`.
struct modinfo64
{
    constexpr modinfo64() = default;
    constexpr modinfo64(const ull mod_) : mod{mod_} {}
    void set_mod(const ull mod_) { mod = mod_; }
    ull mod{};
};
/// Integer modulo `info.mod` (64-bit). `info` is a modinfo64 object bound by reference,
/// so info.set_mod() changes the modulus for every modint64<info>.
///
/// Values are kept in [0, mod). Arithmetic is overflow-free for every mod in [1, 2^64):
///  - sums are reduced with a comparison instead of a 64-bit addition that may wrap;
///  - products use unsigned 128-bit arithmetic;
///  - the ll constructor reduces through unsigned values.
/// Previously the signed __int128 product, the signed ll intermediates in normll and the
/// plain 64-bit sums could overflow once mod exceeded 2^62.
///
/// inv() and operator/ use Fermat's little theorem (x^(mod-2)), so they assume mod is prime.
template<const modinfo64& info>
class modint64
{
public:
    static constexpr const ull& mod = info.mod;
    constexpr modint64() : m_val{0} {}
    constexpr modint64(const ll v) : m_val{normll(v)} {}
    constexpr modint64(const modint64& m) = default;
    /// Stores v without reduction (v should already be below mod).
    constexpr void set_raw(const ull v) { m_val = v; }
    constexpr modint64& operator=(const modint64& m) { return m_val = m(), (*this); }
    constexpr modint64& operator=(const ll v) { return m_val = normll(v), (*this); }
    constexpr modint64 operator+() const { return *this; }
    constexpr modint64 operator-() const { return modint64{0} - (*this); }
    constexpr modint64& operator+=(const modint64& m) { return m_val = add(m_val, m()), *this; }
    constexpr modint64& operator-=(const modint64& m) { return m_val = sub(m_val, m()), *this; }
    constexpr modint64& operator*=(const modint64& m) { return m_val = ull(ULL(m_val) * m() % mod), *this; }
    constexpr modint64& operator/=(const modint64& m) { return *this *= m.inv(); }
    constexpr modint64& operator+=(const ll val) { return *this += modint64{val}; }
    constexpr modint64& operator-=(const ll val) { return *this -= modint64{val}; }
    constexpr modint64& operator*=(const ll val) { return *this *= modint64{val}; }
    constexpr modint64& operator/=(const ll val) { return *this /= modint64{val}; }
    constexpr modint64 operator+(const modint64& m) const { return modint64{*this} += m; }
    constexpr modint64 operator-(const modint64& m) const { return modint64{*this} -= m; }
    constexpr modint64 operator*(const modint64& m) const { return modint64{*this} *= m; }
    constexpr modint64 operator/(const modint64& m) const { return modint64{*this} /= m; }
    constexpr modint64 operator+(const ll v) const { return *this + modint64{v}; }
    constexpr modint64 operator-(const ll v) const { return *this - modint64{v}; }
    constexpr modint64 operator*(const ll v) const { return *this * modint64{v}; }
    constexpr modint64 operator/(const ll v) const { return *this / modint64{v}; }
    constexpr bool operator==(const modint64& m) const { return m_val == m(); }
    constexpr bool operator!=(const modint64& m) const { return not(*this == m); }
    constexpr friend modint64 operator+(const ll v, const modint64& m) { return modint64{v} + m; }
    constexpr friend modint64 operator-(const ll v, const modint64& m) { return modint64{v} - m; }
    constexpr friend modint64 operator*(const ll v, const modint64& m) { return modint64{v} * m; }
    constexpr friend modint64 operator/(const ll v, const modint64& m) { return modint64{v} / m; }
    friend std::istream& operator>>(std::istream& is, modint64& m)
    {
        ll v;
        return is >> v, m = v, is;
    }
    friend std::ostream& operator<<(std::ostream& os, const modint64& m) { return os << m(); }
    /// The stored representative in [0, mod).
    constexpr ull operator()() const { return m_val; }
    /// this^n by binary exponentiation (n may exceed 64 bits).
    constexpr modint64 pow(ULL n) const
    {
        modint64 ans = 1;
        for (modint64 x = *this; n > 0; n >>= 1, x *= x) {
            if (n & ULL{1}) { ans *= x; }
        }
        return ans;
    }
    constexpr modint64 inv() const { return pow(mod - 2); }
    /// Inverse via the memoised small-inverse table, which grows up to the stored value,
    /// so this is only suitable for small values.
    modint64 sinv() const { return sinv(m_val); }
    /// n!, memoised in a growing static table.
    static modint64 fact(const uint n)
    {
        static std::vector<modint64> fs{1, 1};
        for (uint i = (uint)fs.size(); i <= n; i++) { fs.push_back(fs.back() * i); }
        return fs[n];
    }
    /// 1 / n!, memoised in a growing static table.
    static modint64 ifact(const uint n)
    {
        static std::vector<modint64> ifs{1, 1};
        for (uint i = (uint)ifs.size(); i <= n; i++) { ifs.push_back(ifs.back() * sinv(i)); }
        return ifs[n];
    }
    static modint64 perm(const int n, const int k) { return k > n or k < 0 ? modint64{0} : fact(n) * ifact(n - k); }
    static modint64 comb(const int n, const int k) { return k > n or k < 0 ? modint64{0} : fact(n) * ifact(n - k) * ifact(k); }
private:
    // (a + b) mod `mod` for a, b in [0, mod), never forming a sum that could wrap.
    static constexpr ull add(const ull a, const ull b) { return a >= mod - b ? a - (mod - b) : a + b; }
    // (a - b) mod `mod` for a, b in [0, mod).
    static constexpr ull sub(const ull a, const ull b) { return a >= b ? a - b : a + (mod - b); }
    // Reduces any signed 64-bit value into [0, mod). 0ULL - ull(x) is |x|, even for LLONG_MIN.
    static constexpr ull normll(const ll x)
    {
        if (x >= 0) { return ull(x) % mod; }
        const ull r = (0ULL - ull(x)) % mod;
        return r == 0 ? 0ULL : mod - r;
    }
    // Inverses of 1..n from inv(i) = -(mod / i) * inv(mod % i), memoised (needs prime mod, n < mod).
    static modint64 sinv(const uint n)
    {
        static std::vector<modint64> is{1, 1};
        for (uint i = (uint)is.size(); i <= n; i++) { is.push_back(-is[mod % i] * (mod / i)); }
        return is[n];
    }
    ull m_val;
};
/// Buffered stdout writer for integers, vectors of integers and vectors of vectors.
/// Output goes into a 256 KiB buffer and is written with fwrite when:
///  - el() is called,
///  - the buffer is nearly full before a number is written, or
///  - the printer is destroyed.
class printer
{
public:
    printer()
    {
        // m_dict[4*i .. 4*i+3] holds the 4-digit, zero-padded decimal form of i.
        for (int i = 0; i < 10000; i++) {
            for (int j = i, t = 3; t >= 0; t--, j /= 10) { m_dict[i * 4 + t] = (j % 10) + '0'; }
        }
    }
    ~printer() { flush(); }
    /// Writes args separated by spaces, then '\n'. Returns 0.
    template<typename... Args> int ln(const Args&... args) { return dump(args...), put_char('\n'), 0; }
    /// Like ln(), then flushes to stdout. Returns 0.
    template<typename... Args> int el(const Args&... args) { return dump(args...), put_char('\n'), flush(), 0; }
private:
    using ll = long long;
    using ull = unsigned long long;
    static constexpr ull TEN(const int d) { return d == 0 ? 1ULL : TEN(d - 1) * 10ULL; }
    void flush() { fwrite(m_memory, 1, m_tail, stdout), m_tail = 0; }
    void put_char(const char c) { m_memory[m_tail++] = c; }
    // Number of decimal digits of x (1 for x == 0).
    static constexpr int dn(const ull x)
    {
        return x < TEN(10)
                 ? x < TEN(5)
                     ? x < TEN(2)
                         ? x < TEN(1) ? 1 : 2
                         : x < TEN(3) ? 3 : x < TEN(4) ? 4 : 5
                     : x < TEN(7)
                         ? x < TEN(6) ? 6 : 7
                         : x < TEN(8) ? 8 : x < TEN(9) ? 9 : 10
                 : x < TEN(14)
                     ? x < TEN(12)
                         ? x < TEN(11) ? 11 : 12
                         : x < TEN(13) ? 13 : 14
                     : x < TEN(16)
                         ? x < TEN(15) ? 15 : 16
                         : x < TEN(17) ? 17 : x < TEN(18) ? 18 : 19;
    }
    // Empty argument list, so ln() / el() with no arguments print a bare newline
    // (they used to fail to compile).
    void dump() {}
    template<typename T>
    void dump(T v)
    {
        if (C - m_tail < 50) { flush(); }
        // Work on the magnitude as ull. Negating in T overflowed for the minimum value
        // (e.g. LLONG_MIN), and `v < 0` is meaningless for unsigned T.
        ull u = static_cast<ull>(v);
        if constexpr (std::is_signed_v<T>) {
            if (v < 0) { put_char('-'), u = 0ULL - u; }
        }
        const int d = dn(u);
        // Emit 4 digits at a time from the right, then the 1-4 leading digits.
        int i = d - 4;
        for (; i >= 0; i -= 4, u /= 10000) { memcpy(m_memory + m_tail + i, m_dict + (u % 10000) * 4, 4); }
        memcpy(m_memory + m_tail, m_dict + u * 4 - i, i + 4);
        m_tail += d;
    }
    // Elements separated by single spaces.
    template<typename T>
    void dump(const std::vector<T>& vs)
    {
        for (int i = 0; i < (int)vs.size(); i++) {
            if (i > 0) { put_char(' '); }
            dump(vs[i]);
        }
    }
    // One row per line.
    template<typename T>
    void dump(const std::vector<std::vector<T>>& vss)
    {
        for (int i = 0; i < (int)vss.size(); i++) {
            if (i > 0) { put_char('\n'); }
            dump(vss[i]);
        }
    }
    template<typename T, typename... Args> void dump(const T& v, const Args&... args) { return dump(v), put_char(' '), dump(args...), void(0); }
    static constexpr int C = 1 << 18;
    int m_tail = 0;
    char m_memory[C];
    char m_dict[10000 * 4];
} out;
/// Buffered stdin reader for decimal integers separated by whitespace.
class scanner
{
public:
    scanner() {}
    /// Reads one optionally negative decimal integer of type T. Returns 0 at end of input.
    /// Any run of spaces, tabs, '\r' or '\n' before the number is skipped, so CRLF input
    /// and repeated separators work. Previously only the single byte after the preceding
    /// number was consumed, and a second separator was parsed as the number 0.
    template<typename T>
    T val()
    {
        if (m_tail - m_head < 40) { disk_read(); }
        char c = get_char();
        while (c == ' ' or c == '\t' or c == '\r' or c == '\n') { c = get_char(); }
        const bool neg = (c == '-');
        if (neg) { c = get_char(); }
        T ans = 0;
        while ('0' <= c and c <= '9') {
            ans = ans * T{10} + (c - '0');
            c = get_char();
        }
        return (neg ? -ans : ans);
    }
    /// Reads a value and subtracts `offset` from it.
    template<typename T> T val(const T offset) { return val<T>() - offset; }
    /// Reads n values.
    template<typename T> std::vector<T> vec(const int n)
    {
        return make_v<T>(n, [this]() { return val<T>(); });
    }
    /// Reads n values, each shifted by -offset.
    template<typename T> std::vector<T> vec(const int n, const T offset)
    {
        return make_v<T>(n, [this, offset]() { return val<T>(offset); });
    }
    /// Reads an n0 x n1 grid row by row.
    template<typename T> std::vector<std::vector<T>> vvec(const int n0, const int n1)
    {
        return make_v<std::vector<T>>(n0, [this, n1]() { return vec<T>(n1); });
    }
    /// Reads an n0 x n1 grid row by row, each value shifted by -offset.
    template<typename T> std::vector<std::vector<T>> vvec(const int n0, const int n1, const T offset)
    {
        return make_v<std::vector<T>>(n0, [this, n1, offset]() { return vec<T>(n1, offset); });
    }
    /// Reads one value per type in Args. The braced initializer guarantees left-to-right
    /// evaluation, i.e. values are taken in input order.
    template<typename... Args> auto tup() { return std::tuple<std::decay_t<Args>...>{val<Args>()...}; }
    template<typename... Args> auto tup(const Args&... offsets) { return std::tuple<std::decay_t<Args>...>{val<Args>(offsets)...}; }
private:
    // Calls f() n times in order and collects the results.
    template<typename T, typename F>
    std::vector<T> make_v(const int n, F f)
    {
        std::vector<T> ans;
        for (int i = 0; i < n; i++) { ans.push_back(f()); }
        return ans;
    }
    // Next input byte; refills on demand so long separator runs never read past the
    // buffered data. Returns '\0' at end of input.
    char get_char()
    {
        if (m_head == m_tail) { disk_read(); }
        return m_head < m_tail ? m_memory[m_head++] : '\0';
    }
    // Moves the unread bytes to the front of the buffer, then appends as much of stdin as fits.
    void disk_read()
    {
        std::copy(m_memory + m_head, m_memory + m_tail, m_memory);
        m_tail -= m_head, m_head = 0;
        m_tail += fread(m_memory + m_tail, 1, C - m_tail, stdin);
    }
    static constexpr int C = 1 << 18;
    int m_head = 0, m_tail = 0;
    char m_memory[C];
} in;
/// Number of set bits in v.
constexpr int popcount(const ull v) { return v ? __builtin_popcountll(v) : 0; }
/// Bit width of v, i.e. floor(log2(v)) + 1 (0 for v == 0).
constexpr int log2p1(const ull v) { return v ? 64 - __builtin_clzll(v) : 0; }
/// 1-based index of the lowest set bit (0 for v == 0).
constexpr int lsbp1(const ull v) { return __builtin_ffsll(v); }
/// ceil(log2(v)) (0 for v <= 1).
constexpr int clog(const ull v) { return v ? log2p1(v - 1) : 0; }
/// Smallest power of two >= v (1 for v == 0). For v > 2^63 the true value 2^64 does not
/// fit and wraps to 0, so ceil2(v) - 1 is still the all-ones mask. This used to evaluate
/// 1ULL << 64, which is undefined behaviour.
constexpr ull ceil2(const ull v) { return clog(v) >= 64 ? 0ULL : 1ULL << clog(v); }
/// Largest power of two <= v (0 for v == 0).
constexpr ull floor2(const ull v) { return v ? (1ULL << (log2p1(v) - 1)) : 0ULL; }
/// Whether bit `ind` of mask is set.
constexpr bool btest(const ull mask, const int ind) { return (mask >> ind) & 1ULL; }
/// Wrapper around a standard random engine. Adds bounded draws, random bools, random
/// ordered pairs and random vectors. Default construction seeds from std::random_device.
template<typename Rng>
class rng_base
{
public:
    using result_type = typename Rng::result_type;
    static constexpr result_type min() { return Rng::min(); }
    static constexpr result_type max() { return Rng::max(); }
    rng_base() : rng_base(std::random_device{}()) {}
    rng_base(const std::random_device::result_type seed) : m_rng(seed) {}
    ~rng_base() = default;
    result_type operator()() { return m_rng(); }
    /// Uniform value in [0, max], by rejection sampling under the smallest all-ones mask >= max.
    /// NOTE(review): assumes the engine sets every bit of result_type. std::mt19937 yields
    /// 32-bit values while its result_type (uint_fast32_t) is 64-bit on common LP64
    /// platforms, so larger bounds are not fully covered with that engine.
    result_type val(const result_type max = std::numeric_limits<result_type>::max())
    {
        if (max == std::numeric_limits<result_type>::max()) { return m_rng(); }
        // Smear the top bit of max downwards. This equals ceil2(max + 1) - 1, but never
        // shifts by the full bit width; the old form evaluated 1ULL << 64 for max >= 2^63.
        result_type mask = max;
        for (int shift = 1; shift < std::numeric_limits<result_type>::digits; shift <<= 1) { mask |= mask >> shift; }
        while (true) {
            const result_type ans = m_rng() & mask;
            if (ans <= max) { return ans; }
        }
    }
    /// Uniform value in [min, max].
    template<typename T> T val(const T min, const T max) { return min + T(val(max - min)); }
    /// Fair coin flip.
    operator bool() { return val<bool>(0, 1); }
    /// Two draws from [min, max], returned as (smaller, larger).
    template<typename T> std::pair<T, T> pair(const T min, const T max) { return std::minmax(val<T>(min, max), val<T>(min, max)); }
    /// n independent draws from [min, max].
    template<typename T> std::vector<T> vec(const int n, const T min, const T max)
    {
        std::vector<T> vs(n);
        for (auto& v : vs) { v = val<T>(min, max); }
        return vs;
    }
private:
    Rng m_rng;
};
// Process-wide random generators, default-seeded from std::random_device.
rng_base<std::mt19937> rng;
rng_base<std::mt19937_64> rng64;
// Run-time modulus for `mint`; main() sets it per query with info.set_mod(P).
modinfo info;
using mint = modint<info>;
// Run-time 64-bit modulus for `mint64`; main() sets it per query with info64.set_mod(P).
modinfo64 info64;
using mint64 = modint64<info64>;
int main()
{
const int T = in.val<int>();
for (int t = 0; t < T; t++) {
const auto [P, K, A] = in.tup<ull, ull, ull>();
if (P < (1ULL << 31)) {
info.set_mod(P);
const mint ans = mod_nthroot(mint{(ll)A}, K);
if (ans.pow(K) == A) {
out.ln(ans());
} else {
out.ln(-1);
}
} else {
info64.set_mod(P);
const mint64 ans = mod_nthroot(mint64{(ll)A}, K);
if (ans.pow(K) == A) {
out.ln(ans());
} else {
out.ln(-1);
}
}
}
return 0;
}