結果

問題 No.981 一般冪乗根
ユーザー Pachicobue
提出日時 2020-12-31 03:41:48
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 11 ms / 6,000 ms
コード長 19,304 bytes
コンパイル時間 3,700 ms
コンパイル使用メモリ 233,376 KB
実行使用メモリ 10,880 KB
最終ジャッジ日時 2024-04-17 05:29:13
合計ジャッジ時間 48,758 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 5 ms
5,248 KB
testcase_01 AC 4 ms
5,376 KB
testcase_02 AC 4 ms
5,376 KB
testcase_03 AC 4 ms
5,376 KB
testcase_04 AC 3 ms
5,376 KB
testcase_05 AC 4 ms
5,376 KB
testcase_06 AC 4 ms
5,376 KB
testcase_07 AC 4 ms
5,376 KB
testcase_08 AC 4 ms
5,376 KB
testcase_09 AC 4 ms
5,376 KB
testcase_10 AC 4 ms
5,376 KB
testcase_11 AC 4 ms
5,376 KB
testcase_12 AC 6 ms
5,376 KB
testcase_13 AC 3 ms
5,376 KB
testcase_14 AC 4 ms
5,376 KB
testcase_15 AC 4 ms
5,376 KB
testcase_16 AC 4 ms
5,376 KB
testcase_17 AC 4 ms
5,376 KB
testcase_18 AC 3 ms
5,376 KB
testcase_19 AC 3 ms
5,376 KB
testcase_20 AC 4 ms
5,376 KB
testcase_21 AC 4 ms
5,376 KB
testcase_22 AC 4 ms
5,376 KB
testcase_23 AC 4 ms
5,376 KB
testcase_24 AC 4 ms
5,376 KB
testcase_25 AC 3 ms
5,376 KB
testcase_26 AC 2 ms
5,376 KB
testcase_27 AC 4 ms
5,376 KB
testcase_28 AC 11 ms
10,880 KB
evil_60bit1.txt AC 5 ms
5,376 KB
evil_60bit2.txt AC 5 ms
5,376 KB
evil_60bit3.txt AC 5 ms
5,376 KB
evil_hack AC 2 ms
5,376 KB
evil_hard_random AC 5 ms
5,376 KB
evil_hard_safeprime.txt AC 6 ms
5,376 KB
evil_hard_tonelli0 AC 4 ms
5,376 KB
evil_hard_tonelli1 AC 347 ms
20,096 KB
evil_hard_tonelli2 AC 42 ms
20,352 KB
evil_hard_tonelli3 AC 40 ms
5,376 KB
evil_sefeprime1.txt AC 4 ms
5,376 KB
evil_sefeprime2.txt AC 4 ms
5,376 KB
evil_sefeprime3.txt AC 5 ms
5,376 KB
evil_tonelli1.txt AC 862 ms
20,152 KB
evil_tonelli2.txt AC 702 ms
20,352 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
// Fixed-width arithmetic aliases used throughout the file.
using i32 = int;
using u32 = unsigned int;
using i64 = long long;
using u64 = unsigned long long;
using i128 = __int128_t;
using u128 = __uint128_t;
using f64 = double;
using f80 = long double;
using f128 = __float128;
// Legacy short-hands; each names exactly the same type as one of the aliases above.
using uint = u32;
using ll = i64;
using ull = u64;
using ld = f80;
using LL = i128;
using ULL = u128;
// Priority queues: max_heap pops the largest element first, min_heap the smallest.
template<typename T> using max_heap = std::priority_queue<T>;
template<typename T> using min_heap = std::priority_queue<T, std::vector<T>, std::greater<T>>;
struct moddinfo
{
    u32 mod{1}, r2{1}, inv{1};
    void set_mod(const u32 nmod)
    {
        assert(nmod < (1ULL << 30)), assert(nmod & 1);
        mod = nmod, r2 = -u64(mod) % mod;
        inv = mod;
        for (int i = 0; i < 4; i++) { inv *= 2 - mod * inv; }
        assert(inv * mod == 1);
    }
    u32 reduce(const u64 x) const { return (x + u64(u32(x) * u32(-inv)) * mod) >> 32; }
};
// Modular integer modulo the runtime modulus stored in `info` (odd, below 2^30 per
// moddinfo::set_mod), kept in Montgomery form with R = 2^32. The stored value m_val is only
// reduced lazily into [0, 2 * mod); operator== and get() normalize it before use.
template<const moddinfo& info>
class moddint
{
public:
    // Refers to the live modulus, so it changes whenever info.set_mod() is called.
    static constexpr const u32& mod = info.mod;
    constexpr moddint() : m_val{0} {}
    // Converts v of any sign: v % mod + mod lies in (0, 2 * mod), and reduce(x * R^2) == x * R.
    constexpr moddint(const i64 v) : m_val{info.reduce(u64(v % mod + mod) * info.r2)} {}
    constexpr moddint operator-() const { return moddint{} - (*this); }
    // Lazy addition: subtract 2 * mod up front and add it back if the result went negative,
    // keeping m_val in [0, 2 * mod). The signed view is safe because 4 * mod < 2^32.
    constexpr moddint& operator+=(const moddint& m)
    {
        if (i32(m_val += m.m_val - 2 * mod) < 0) { m_val += 2 * mod; }
        return *this;
    }
    // Lazy subtraction with the same [0, 2 * mod) invariant.
    constexpr moddint& operator-=(const moddint& m)
    {
        if (i32(m_val -= m.m_val) < 0) { m_val += 2 * mod; }
        return *this;
    }
    // Montgomery product: reduce((a * R) * (b * R)) == a * b * R.
    constexpr moddint& operator*=(const moddint& m) { return m_val = info.reduce(u64(m_val) * m.m_val), *this; }
    // Division through the Fermat inverse; assumes mod is prime.
    constexpr moddint& operator/=(const moddint& m) { return *this *= m.inv(); }
    constexpr moddint operator+(const moddint& m) const { return moddint{*this} += m; }
    constexpr moddint operator-(const moddint& m) const { return moddint{*this} -= m; }
    constexpr moddint operator*(const moddint& m) const { return moddint{*this} *= m; }
    constexpr moddint operator/(const moddint& m) const { return moddint{*this} /= m; }
    // Compares residues after folding both representatives into [0, mod).
    constexpr bool operator==(const moddint& m) const { return (m_val >= mod ? m_val - mod : m_val) == (m.m_val >= mod ? m.m_val - mod : m.m_val); }
    constexpr bool operator!=(const moddint& m) const { return not(*this == m); }
    // Reads a signed 64-bit integer and converts it to a residue.
    friend std::istream& operator>>(std::istream& is, moddint& m)
    {
        i64 v;
        return is >> v, m = v, is;
    }
    // Writes the canonical residue in [0, mod).
    friend std::ostream& operator<<(std::ostream& os, const moddint& m) { return os << m(); }
    // Leaves Montgomery form (reduce(m_val) == value) and returns the residue in [0, mod).
    constexpr u32 get() const
    {
        const u32 m = info.reduce(m_val);
        return m >= mod ? m - mod : m;
    }
    constexpr u32 operator()() const { return get(); }
    // Binary exponentiation: *this^n.
    constexpr moddint pow(u64 n) const
    {
        moddint ans = 1;
        for (moddint x = *this; n > 0; n >>= 1, x *= x) {
            if (n & 1) { ans *= x; }
        }
        return ans;
    }
    // Fermat inverse *this^(mod - 2); valid only for prime mod and nonzero *this.
    constexpr moddint inv() const { return pow(mod - 2); }
private:
    // Montgomery representative, kept in [0, 2 * mod).
    u32 m_val;
};
struct moddinfo64
{
    u64 mod{1}, r2{1}, inv{1};
    void set_mod(const u64 nmod)
    {
        assert(nmod < (1ULL << 62)), assert(nmod & 1);
        mod = nmod, r2 = -u128(mod) % mod;
        inv = mod;
        for (int i = 0; i < 5; i++) { inv *= 2 - mod * inv; }
        assert(inv * mod == 1);
    }
    u64 reduce(const u128& x) const { return (x + u128((u64)x * -inv) * mod) >> 64; }
};
// 64-bit counterpart of moddint: modular integer modulo the runtime modulus in `info`
// (odd, below 2^62 per moddinfo64::set_mod), kept in Montgomery form with R = 2^64.
// The stored value m_val is reduced lazily into [0, 2 * mod); operator== and get()
// normalize it before use.
template<const moddinfo64& info>
class moddint64
{
public:
    // Refers to the live modulus, so it changes whenever info.set_mod() is called.
    static constexpr const u64& mod = info.mod;
    constexpr moddint64() : m_val{0} {}
    // Converts v of any sign into Montgomery form. v is first reduced into (-mod, mod), so the
    // shifted value lies in (0, 2 * mod) and its product with r2 stays below mod * 2^64, as
    // reduce() requires. (Adding mod to an unreduced v would wrap around in u128 for v < -mod.)
    constexpr moddint64(const i64 v) : m_val{info.reduce(u128(v % i64(mod) + i64(mod)) * info.r2)} {}
    constexpr moddint64 operator-() const { return moddint64{} - (*this); }
    // Lazy addition: subtract 2 * mod up front and undo it if the result went negative,
    // keeping m_val in [0, 2 * mod). The signed view is safe because 4 * mod < 2^64.
    constexpr moddint64& operator+=(const moddint64& m)
    {
        if (i64(m_val += m.m_val - 2 * mod) < 0) { m_val += 2 * mod; }
        return *this;
    }
    // Lazy subtraction with the same [0, 2 * mod) invariant.
    constexpr moddint64& operator-=(const moddint64& m)
    {
        if (i64(m_val -= m.m_val) < 0) { m_val += 2 * mod; }
        return *this;
    }
    // Montgomery product: reduce((a * R) * (b * R)) == a * b * R.
    constexpr moddint64& operator*=(const moddint64& m) { return m_val = info.reduce(u128(m_val) * m.m_val), *this; }
    // Division through the Fermat inverse; assumes mod is prime.
    constexpr moddint64& operator/=(const moddint64& m) { return *this *= m.inv(); }
    constexpr moddint64 operator+(const moddint64& m) const { return moddint64{*this} += m; }
    constexpr moddint64 operator-(const moddint64& m) const { return moddint64{*this} -= m; }
    constexpr moddint64 operator*(const moddint64& m) const { return moddint64{*this} *= m; }
    constexpr moddint64 operator/(const moddint64& m) const { return moddint64{*this} /= m; }
    // Compares residues after folding both representatives into [0, mod).
    constexpr bool operator==(const moddint64& m) const { return (m_val >= mod ? m_val - mod : m_val) == (m.m_val >= mod ? m.m_val - mod : m.m_val); }
    constexpr bool operator!=(const moddint64& m) const { return not(*this == m); }
    // Reads a signed 64-bit integer (negative values included) and converts it to a residue.
    friend std::istream& operator>>(std::istream& is, moddint64& m)
    {
        i64 v;
        return is >> v, m = v, is;
    }
    // Writes the canonical residue in [0, mod).
    friend std::ostream& operator<<(std::ostream& os, const moddint64& m) { return os << m(); }
    // Leaves Montgomery form (reduce(m_val) == value) and returns the residue in [0, mod).
    constexpr u64 get() const
    {
        const u64 m = info.reduce(m_val);
        return m >= mod ? m - mod : m;
    }
    constexpr u64 operator()() const { return get(); }
    // Binary exponentiation: *this^n.
    constexpr moddint64 pow(u128 n) const
    {
        moddint64 ans = 1;
        for (moddint64 x = *this; n > 0; n >>= 1, x *= x) {
            if (n & 1) { ans *= x; }
        }
        return ans;
    }
    // Fermat inverse *this^(mod - 2); valid only for prime mod and nonzero *this.
    constexpr moddint64 inv() const { return pow(mod - 2); }
private:
    // Montgomery representative, kept in [0, 2 * mod).
    u64 m_val;
};
// Fixed-capacity open-addressing hash map with linear probing over 2^LG slots.
// Keys must be convertible to unsigned long long for hashing. The table never grows:
// callers must keep the number of live keys strictly below 2^LG, otherwise probing for an
// absent key never terminates. Keys and values are stored inline, so with the default
// LG = 20 and 64-bit K and V an instance occupies about 16 MiB; prefer heap or static storage.
template<typename K, typename V, int LG = 20>
class hashmap
{
public:
    hashmap() = default;
    // Returns a reference to the value for k, inserting a value-initialized V if k is absent.
    V& operator[](const K& k)
    {
        for (unsigned int i = hash(k);; i = next(i)) {
            if (not m_used.test(i)) {
                m_keys[i] = k, m_used.set(i);
                return m_vals[i] = V{};
            }
            if (m_keys[i] == k) { return m_vals[i]; }
        }
    }
    // Removes k if present (no-op otherwise). Uses backward-shift deletion: merely clearing
    // the slot would end the probe chain there and make later keys of the cluster unreachable.
    void erase(const K& k)
    {
        unsigned int hole = find_slot(k);
        if (not m_used.test(hole)) { return; }
        for (unsigned int j = next(hole); m_used.test(j); j = next(j)) {
            // Entry j may move into the hole only if its home slot is not cyclically in (hole, j].
            const unsigned int home = hash(m_keys[j]);
            if (((j - home) & (N - 1)) >= ((j - hole) & (N - 1))) {
                m_keys[hole] = m_keys[j];
                m_vals[hole] = std::move(m_vals[j]);
                hole = j;
            }
        }
        m_used.reset(hole);
    }
    // Returns 1 if k is present and 0 otherwise.
    int count(const K& k) const { return m_used.test(find_slot(k)) ? 1 : 0; }
private:
    static constexpr int N = 1 << LG;
    static constexpr unsigned long long r = 11995408973635179863ULL;
    // Multiplicative hash: the top LG bits of a * r.
    static constexpr unsigned int hash(const unsigned long long a) { return (a * r) >> (64 - LG); }
    // Next slot in probe order, wrapping around the table.
    static constexpr unsigned int next(const unsigned int i) { return (i + 1) & (N - 1); }
    // Slot holding k, or the empty slot that ends k's probe sequence when k is absent.
    unsigned int find_slot(const K& k) const
    {
        unsigned int i = hash(k);
        while (m_used.test(i) and m_keys[i] != k) { i = next(i); }
        return i;
    }
    std::bitset<N> m_used;
    K m_keys[N];
    V m_vals[N];
};
// Modular inverse of a modulo mod for coprime a >= 1 and mod, via the recursion
// inverse(a, m) = ((a - inverse(m % a, a)) * m + 1) / a. For mod >= 2 the result is the
// representative in [1, mod). Intermediates reach about a * mod, so callers pick a wide T.
template<typename T> constexpr T inverse(const T a, const T mod)
{
    if (a == 1) { return T{1}; }
    // t satisfies mod * t == 1 (mod a), so (a - t) * mod + 1 is divisible by a.
    const T t = inverse(mod % a, a);
    return ((a - t) * mod + 1) / a;
}
// Solves a * x - b * y == 1 for coprime a, b and returns {x, y}.
// For positive a, b, x = inverse(a, b) and y = (a * x - 1) / b.
// Degenerate inputs: a == 0 needs |b| == 1, and b == 0 needs |a| == 1.
template<typename T> constexpr std::pair<T, T> extgcd(const T a, const T b)
{
    // 0 * x - b * y == 1  =>  x = 0, y = -1 / b.
    if (a == 0) { return {T{0}, T(-1 / b)}; }
    // a * x - 0 * y == 1  =>  x = 1 / a, y = 0.
    if (b == 0) { return {T(1 / a), T{0}}; }
    const T x = inverse(a, b), y = (a * x - 1) / b;
    return {x, y};
}
// Number of set bits in v.
constexpr int popcount(const ull v) { return __builtin_popcountll(v); }
// Bit length of v: floor(log2(v)) + 1, or 0 when v == 0 (__builtin_clzll(0) is undefined).
constexpr int log2p1(const ull v) { return v == 0 ? 0 : 64 - __builtin_clzll(v); }
// One-based index of the lowest set bit, or 0 when v == 0.
constexpr int lsbp1(const ull v) { return __builtin_ffsll(v); }
// ceil(log2(v)) for v >= 1, and 0 for v == 0.
constexpr int clog(const ull v) { return v <= 1 ? 0 : log2p1(v - 1); }
// Smallest power of two >= v (1 for v <= 1). The shift is undefined for v > 2^63.
constexpr ull ceil2(const ull v) { return 1ULL << clog(v); }
// Largest power of two <= v, or 0 when v == 0.
constexpr ull floor2(const ull v) { return v == 0 ? 0ULL : 1ULL << (63 - __builtin_clzll(v)); }
// True when v has at most one set bit; note that this includes v == 0.
constexpr bool ispow2(const ull v) { return !(v & (v - 1)); }
// Whether bit `ind` (0-based, must be below 64) of mask is set.
constexpr bool btest(const ull mask, const int ind) { return (mask & (1ULL << ind)) != 0; }
// Thin wrapper over a standard random engine with helpers for uniform integers,
// ordered pairs and vectors.
template<typename Rng>
class rng_base
{
public:
    using result_type = typename Rng::result_type;
    static constexpr result_type min() { return Rng::min(); }
    static constexpr result_type max() { return Rng::max(); }
    // Seeds from std::random_device.
    rng_base() : rng_base(std::random_device{}()) {}
    rng_base(const std::random_device::result_type seed) : m_rng(seed) {}
    ~rng_base() = default;
    // One raw engine output.
    result_type operator()() { return m_rng(); }
    // Uniform integer in [0, max]. std::uniform_int_distribution combines several engine
    // outputs when `max` exceeds the engine's range (e.g. 64-bit bounds drawn from the 32-bit
    // std::mt19937, whose result_type may be 64 bits wide); masking a single output cannot.
    result_type val(const result_type max = std::numeric_limits<result_type>::max())
    {
        return std::uniform_int_distribution<result_type>{0, max}(m_rng);
    }
    // Uniform value in [min, max].
    template<typename T> T val(const T min, const T max) { return min + T(val(max - min)); }
    // Random bit. NOTE(review): this is an implicit conversion, so `rng` converts to bool silently.
    operator bool() { return val<bool>(0, 1); }
    // Two uniform values in [min, max], returned in non-decreasing order.
    template<typename T> std::pair<T, T> pair(const T min, const T max) { return std::minmax(val<T>(min, max), val<T>(min, max)); }
    // n independent uniform values in [min, max].
    template<typename T> std::vector<T> vec(const int n, const T min, const T max)
    {
        std::vector<T> vs(n);
        for (auto& v : vs) { v = val<T>(min, max); }
        return vs;
    }
private:
    Rng m_rng;
};
rng_base<std::mt19937> rng;
rng_base<std::mt19937_64> rng64;
// Miller-Rabin test of odd n > 2 against the witnesses in `as`, using the Montgomery type
// `mint`, whose modulus must already be set to n. Returns false as soon as a witness proves
// n composite, true otherwise.
template<typename mint>
inline bool miller_rabin(const u64 n, const std::vector<u64>& as)
{
    // n - 1 = d * 2^s with d odd.
    auto d = n - 1;
    for (; (d & 1) == 0; d >>= 1) {}
    for (const u64 a : as) {
        // Witnesses are used modulo n (mint(a) reduces them). One divisible by n carries no
        // information and would make a prime look composite, so skip just that one. Breaking
        // out of the loop instead would also drop every later witness.
        if (a % n == 0) { continue; }
        auto s = d;
        mint x = mint(a).pow(s);
        while (x.get() != 1 and x.get() != n - 1 and s != n - 1) { x *= x, s <<= 1; }
        if (x.get() != n - 1 and s % 2 == 0) { return false; }
    }
    return true;
}
// Deterministic primality test. Odd n must be below 2^62 (the Montgomery types assert it).
// Below 2^30 it uses witnesses {2, 7, 61}, proven sufficient for n < 4,759,123,141.
// Otherwise it uses the seven-witness set {2, 325, 9375, 28178, 450775, 9780504,
// 1795265022}, proven sufficient for all n < 2^64; a strict subset carries no such guarantee.
inline bool is_prime(const u64 n)
{
    static moddinfo info;
    static moddinfo64 info64;
    using mint = moddint<info>;
    using mint64 = moddint64<info64>;
    if (n < 2) { return false; }
    if ((n & 1) == 0) { return n == 2; }
    if (n < (1ULL << 30)) {
        info.set_mod(n);
        return miller_rabin<mint>(n, {2, 7, 61});
    } else {
        info64.set_mod(n);
        return miller_rabin<mint64>(n, {2, 325, 9375, 28178, 450775, 9780504, 1795265022});
    }
}
// Returns a nontrivial divisor of n, n itself when is_prime(n) holds, or 2 when n is even.
// Brent's variant of Pollard's rho with f(x) = x^2 + c and gcds batched over dk steps.
// The caller must have set mint's modulus to n (prime_factors does so before calling).
template<typename mint>
u64 pollard_rho(const u64 n)
{
    if (n % 2 == 0) { return 2; }
    if (is_prime(n)) { return n; }
    mint c;
    auto f = [&](const mint x) { return x * x + c; };
    // Each attempt draws a random start y and increment c; it retries if the gcd collapses to n.
    while (true) {
        // q accumulates the product of |x - y| differences for the whole attempt.
        mint x, y, ys, q = 1;
        y = rng.val<u64>(0, n - 2) + 2, c = rng.val<u64>(0, n - 2) + 2;
        u64 d = 1;
        // Number of differences multiplied into q between two gcd computations.
        constexpr u32 dk = 128;
        // Brent's cycle search: x is frozen at the start of each power-of-two window of length r.
        for (u32 r = 1; d == 1; r <<= 1) {
            x = y;
            for (u32 i = 0; i < r; i++) { y = f(y); }
            for (u32 k = 0; k < r and d == 1; k += dk) {
                // Remember where this batch starts so it can be replayed step by step.
                ys = y;
                for (u32 i = 0; i < dk and i < r - k; i++) { q *= x - (y = f(y)); }
                d = std::gcd((u64)q.get(), n);
            }
        }
        // The batched product became divisible by n: replay the last batch one gcd at a time.
        if (d == n) {
            do {
                d = std::gcd(u64((x - (ys = f(ys))).get()), n);
            } while (d == 1);
        }
        if (d != n) { return d; }
    }
    // Unreachable: the loop above only exits through return.
    return n;
}
// Returns the prime factorization of n as {prime: exponent}, for n below 2^62.
// n == 1 yields an empty map, and so does n == 0, which has no prime factorization.
std::map<u64, int> prime_factors(const u64 n)
{
    static moddinfo info;
    static moddinfo64 info64;
    using mint = moddint<info>;
    using mint64 = moddint64<info64>;
    std::map<u64, int> ans;
    // Without this guard the loop that strips factors of two never terminates for 0.
    if (n == 0) { return ans; }
    auto fac = [&](auto self, ull x) -> void {
        // Strip factors of two first: the Montgomery types only accept odd moduli.
        while ((x & 1) == 0) { x >>= 1, ans[2]++; }
        if (x == 1) { return; }
        u64 p;
        // Pick the narrowest Montgomery type able to represent x.
        if (x < (1ULL << 30)) {
            info.set_mod(x);
            p = pollard_rho<mint>(x);
        } else {
            info64.set_mod(x);
            p = pollard_rho<mint64>(x);
        }
        // pollard_rho returns x itself exactly when x is prime.
        if (p == x) { return ans[p]++, void(0); }
        // p is a nontrivial (not necessarily prime) divisor: factor both parts.
        self(self, p), self(self, x / p);
    };
    return fac(fac, n), ans;
}
template<typename mint>
mint mod_nthroot(mint A, ull k)
{
    const ull P = mint::mod;
    if (A == 0) { return 0; }
    if (k == 0) { return A; }
    const ull g = std::gcd(P - 1, k);
    if (A.pow((P - 1) / g)() != 1) { return 0; }
    A = A.pow(inverse<ULL>(k / g, (P - 1) / g));
    if (g == 1) { return A; }
    const auto fs = prime_factors(g);
    for (const auto& [p, e] : fs) {
        ull pe = 1;
        for (int i = 0; i < e; i++) { pe *= p; }
        ull q = P - 1, Q = 0;
        while (q % p == 0) { q /= p, Q++; }
        const ull y = pe - inverse<ULL>(q, pe), z = ((ULL)y * q + 1) / (ULL)pe;
        mint X = A.pow(z);
        if ((int)Q == e) {
            A = X;
            continue;
        }
        mint Eraser = 1;
        const ull h = (P - 1) / p;
        for (mint Z = 2;; Z += 1) {
            if (Z.pow(h) != 1) {
                Eraser = Z.pow(q);
                break;
            }
        }
        mint Error = A.pow((ULL)y * q);
        mint pEraser = Eraser;
        for (ull i = 0; i < Q - 1; i++) { pEraser = pEraser.pow(p); }
        const mint ipEraser = pEraser.inv();
        hashmap<ull, ull> memo;
        {
            const ull M = std::max(1ULL, (ull)(std::sqrt(p) * std::sqrt(Q - e))), B = std::max(1ULL, ((ull)p - 1) / M);
            const mint ppEraser = pEraser.pow(B);
            mint prod = 1;
            for (ull i = 0; i < p; i += B, prod *= ppEraser) { memo[prod()] = i; }
        }
        while (Error() != 1) {
            ull l = 0;
            mint pError = Error;
            for (ull i = 0; i < Q; i++) {
                const auto npError = pError.pow(p);
                if (npError == 1) {
                    l = Q - (i + 1);
                    break;
                }
                pError = npError;
            }
            ull c = -1;
            {
                mint small = pError.inv();
                for (ull j = 0;; j++, small *= ipEraser) {
                    if (memo.count(small())) {
                        const ull i = memo[small()];
                        c = i + j;
                        break;
                    }
                }
            }
            auto pEraser2 = Eraser.pow(c);
            for (ull i = 0; i < l - e; i++) { pEraser2 = pEraser2.pow(p); }
            X *= pEraser2, Error *= pEraser2.pow(pe);
        }
        A = X;
    }
    return A;
}
// Buffered fast writer for stdout. ln() writes its arguments separated by single spaces and
// followed by a newline; el() does the same and then flushes everything to the OS (useful
// for interactive judges). Integers are formatted four digits at a time via a lookup table.
class printer
{
public:
    printer()
    {
        // m_dict[4 * i .. 4 * i + 3] holds i (0 <= i < 10000) as four zero-padded digits.
        for (int i = 0; i < 10000; i++) {
            for (int j = i, t = 3; t >= 0; t--, j /= 10) { m_dict[i * 4 + t] = (j % 10) + '0'; }
        }
    }
    ~printer() { flush(); }
    template<typename... Args> int ln(const Args&... args) { return dump(args...), put_char('\n'), 0; }
    // fwrite() alone leaves the bytes in stdout's stdio buffer, so flush that buffer too.
    template<typename... Args> int el(const Args&... args) { return dump(args...), put_char('\n'), flush(), std::fflush(stdout), 0; }
private:
    using ll = long long;
    using ull = unsigned long long;
    static constexpr ull TEN(const int d) { return d == 0 ? 1ULL : TEN(d - 1) * 10ULL; }
    void flush() { fwrite(m_memory, 1, m_tail, stdout), m_tail = 0; }
    // Appends one byte, draining the buffer first when it is full so that long strings
    // cannot write past the end of m_memory.
    void put_char(const char c)
    {
        if (m_tail == C) { flush(); }
        m_memory[m_tail++] = c;
    }
    // Number of decimal digits of x (1 for x == 0).
    static constexpr int dn(const ull x)
    {
        return x < TEN(10)
                   ? x < TEN(5)
                         ? x < TEN(2)
                               ? x < TEN(1) ? 1 : 2
                               : x < TEN(3) ? 3 : x < TEN(4) ? 4 : 5
                         : x < TEN(7)
                               ? x < TEN(6) ? 6 : 7
                               : x < TEN(8) ? 8 : x < TEN(9) ? 9 : 10
                   : x < TEN(14)
                         ? x < TEN(12)
                               ? x < TEN(11) ? 11 : 12
                               : x < TEN(13) ? 13 : 14
                         : x < TEN(16)
                               ? x < TEN(15) ? 15 : 16
                               : x < TEN(17) ? 17 : x < TEN(18) ? 18 : 19;
    }
    void dump(const char* s)
    {
        for (; *s != '\0'; ++s) { put_char(*s); }
    }
    void dump(const std::string& s)
    {
        for (const char c : s) { put_char(c); }
    }
    // Writes an integer (magnitude below 2^64) in decimal.
    template<typename T>
    void dump(T v)
    {
        static_assert(!std::is_floating_point<T>::value, "printer: floating-point output is not supported");
        // Reserve room for a sign plus up to 20 digits so the memcpy calls below need no checks.
        if (C - m_tail < 50) { flush(); }
        // Work on the unsigned magnitude: negating v itself overflows for the minimum value of T.
        ull u = static_cast<ull>(v);
        if (v < 0) { put_char('-'), u = 0ULL - u; }
        const int d = dn(u);
        int i = d - 4;
        // Emit complete 4-digit groups from the right; afterwards i lies in [-4, -1].
        for (; i >= 0; i -= 4, u /= 10000) { std::memcpy(m_memory + m_tail + i, m_dict + (u % 10000) * 4, 4); }
        // The leading group has i + 4 digits: the tail of its zero-padded table entry.
        std::memcpy(m_memory + m_tail, m_dict + u * 4 + (-i), i + 4);
        m_tail += d;
    }
    template<typename T>
    void dump(const std::vector<T>& vs)
    {
        for (int i = 0; i < (int)vs.size(); i++) {
            if (i > 0) { put_char(' '); }
            dump(vs[i]);
        }
    }
    template<typename T>
    void dump(const std::vector<std::vector<T>>& vss)
    {
        for (int i = 0; i < (int)vss.size(); i++) {
            if (i > 0) { put_char('\n'); }
            dump(vss[i]);
        }
    }
    template<typename T, typename... Args> void dump(const T& v, const Args&... args) { return dump(v), put_char(' '), dump(args...), void(0); }
    static constexpr int C = 1 << 18;
    int m_tail = 0;
    char m_memory[C];
    char m_dict[10000 * 4];
} out;
// Buffered fast reader for whitespace-separated decimal integers from stdin.
class scanner
{
public:
    scanner() {}
    // Reads the next (optionally negative) integer, skipping leading blanks, tabs, CR and LF.
    // The single character that ends the number is consumed. At end of input it returns 0.
    template<typename T>
    T val()
    {
        char c = get_char();
        while (c == ' ' or c == '\n' or c == '\r' or c == '\t') { c = get_char(); }
        const bool neg = (c == '-');
        if (neg) { c = get_char(); }
        T ans = 0;
        while ('0' <= c and c <= '9') {
            ans = ans * T{10} + (c - '0');
            c = get_char();
        }
        return (neg ? -ans : ans);
    }
    template<typename T> T val(const T offset) { return val<T>() - offset; }
    template<typename T> std::vector<T> vec(const int n)
    {
        return make_v<T>(n, [this]() { return val<T>(); });
    }
    template<typename T> std::vector<T> vec(const int n, const T offset)
    {
        return make_v<T>(n, [this, offset]() { return val<T>(offset); });
    }
    template<typename T> std::vector<std::vector<T>> vvec(const int n0, const int n1)
    {
        return make_v<std::vector<T>>(n0, [this, n1]() { return vec<T>(n1); });
    }
    template<typename T> std::vector<std::vector<T>> vvec(const int n0, const int n1, const T offset)
    {
        return make_v<std::vector<T>>(n0, [this, n1, offset]() { return vec<T>(n1, offset); });
    }
    // Braced initialization guarantees left-to-right evaluation of the reads.
    template<typename... Args> auto tup() { return std::tuple<std::decay_t<Args>...>{val<Args>()...}; }
    template<typename... Args> auto tup(const Args&... offsets) { return std::tuple<std::decay_t<Args>...>{val<Args>(offsets)...}; }
private:
    template<typename T, typename F>
    std::vector<T> make_v(const int n, F f)
    {
        std::vector<T> ans;
        for (int i = 0; i < n; i++) { ans.push_back(f()); }
        return ans;
    }
    // Next input byte. Refills the buffer when it runs dry, so tokens may straddle buffer
    // boundaries. Returns '\0' at end of input.
    char get_char()
    {
        if (m_head == m_tail) {
            disk_read();
            if (m_head == m_tail) { return '\0'; }
        }
        return m_memory[m_head++];
    }
    // Moves unread bytes to the front and tops up the buffer from stdin.
    void disk_read()
    {
        std::copy(m_memory + m_head, m_memory + m_tail, m_memory);
        m_tail -= m_head, m_head = 0;
        m_tail += fread(m_memory + m_tail, 1, C - m_tail, stdin);
    }
    static constexpr int C = 1 << 18;
    int m_head = 0, m_tail = 0;
    char m_memory[C];
} in;
// Modulus registries used by main(): `mint` for primes below 2^30, `mint64` for larger ones.
moddinfo info;
moddinfo64 info64;
using mint = moddint<info>;
using mint64 = moddint64<info64>;
int main()
{
    const int T = in.val<int>();
    for (int t = 0; t < T; t++) {
        const auto [P, K, A] = in.tup<u64, u64, u64>();
        if (P == 2) {
            out.ln(A == 1 ? 1 : K == 0 ? -1 : 0);
            continue;
        }
        if (P < (1ULL << 30)) {
            info.set_mod(P);
            const mint ans = mod_nthroot(mint(A), K);
            if (ans.pow(K) == A) {
                out.ln(ans());
            } else {
                out.ln(-1);
            }
        } else {
            info64.set_mod(P);
            const mint64 ans = mod_nthroot(mint64(A), K);
            if (ans.pow(K) == A) {
                out.ln(ans());
            } else {
                out.ln(-1);
            }
        }
    }
    return 0;
}
0