結果
問題 | No.2506 Sum of Weighted Powers |
ユーザー |
|
提出日時 | 2023-03-22 02:47:06 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 | WA |
実行時間 | - |
コード長 | 4,523 bytes |
コンパイル時間 | 2,892 ms |
コンパイル使用メモリ | 141,708 KB |
最終ジャッジ日時 | 2025-02-11 16:06:43 |
ジャッジサーバーID (参考情報) | judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 15 WA * 27 |
ソースコード
#include <iostream>#include <cassert>#include <cmath>#include <map>long long pow(long long a, long long n, long long p) {long long r = 1;for (;n > 0;n >>= 1, a = a * a % p)if (n % 2 == 1)r = r * a % p;return r;}int cnt(long long a, long long base, long long p) {int ret = 0;while (a != 1) {a = pow(a, base, p);++ret;}return ret;}long long inv(long long a, long long p) {a %= p;long long u = 1, v = 0;long long b = p;while (b > 0) {long long q = a / b;a %= b;u -= v * q % p;u = (u % p + p) % p;{u ^= v;v ^= u;u ^= v;a ^= b;b ^= a;a ^= b;}}return u < 0 ? u + p : u;}long long gcd(long long a, long long b) {return a == 0 ? b : gcd(b % a, a);}long long peth_root(long long a, long long p, int e, long long mod) {long long q = mod - 1;int s = 0;while (q % p == 0) {q /= p;++s;}long long pe = pow(p, e, mod);long long ans = pow(a, ((pe - 1) * inv(q, pe) % pe * q + 1) / pe, mod);long long c = 2;while (pow(c, (mod - 1) / p, mod) == 1)++c;c = pow(c, q, mod);std::map<long long, int> map;long long add = 1;int v = (int) std::sqrt((double) (s - e) * p) + 1;long long mul = pow(c, v * pow(p, s - 1, mod - 1) % (mod - 1), mod);for (int i = 0;i <= v;++i) {map[add] = i;add = add * mul % mod;}mul = inv(pow(c, pow(p, s - 1, mod - 1), mod), mod);for (int i = e;i < s;++i) {long long err = inv(pow(ans, pe, mod), mod) * a % mod;long long target = pow(err, pow(p, s - 1 - i, mod - 1), mod);for (int j = 0; j <= v; ++j) {if (map.find(target) != map.end()) {int x = map[target];ans = ans * pow(c, (j + v * x) * pow(p, i - e, mod - 1) % (mod - 1), mod) % mod;break;}target = target * mul % mod;assert(j != v);}}return ans;}long long kth_root(long long a, long long k, long long p) {if (k > 0 && a == 0)return 0;k %= p - 1;long long g = gcd(k, p - 1);if (pow(a, (p - 1) / g, p) != 1)return -1;a = pow(a, inv(k / g, (p - 1) / g), p);for (long long div = 2;div * div <= g;++div) {int sz = 0;while (g % div == 0) {g /= div;++sz;}if (sz > 0) {long long b = peth_root(a, div, sz, p);a = b;}}if (g > 1)a = peth_root(a, g, 1, 
p);return a;}#include <atcoder/modint>using mint = atcoder::static_modint<943718401>;namespace atcoder {std::istream& operator>>(std::istream& in, mint &a) {long long e; in >> e; a = e;return in;}std::ostream& operator<<(std::ostream& out, const mint &a) {out << a.val();return out;}} // namespace atcoder#include <atcoder/convolution>mint solve(const int n, const mint x, const std::vector<mint> &a, const std::vector<mint> &b, const std::vector<mint> &c) {if (x == 0) {mint ans = 0;for (int i = 0; i <= n; ++i) {ans += a[i] * b[i] * c[0];}for (int i = 1; i <= n; ++i) {ans += a[i] * b[0] * c[i];}return ans;}int cbrt_x_ = kth_root(x.val(), 3, mint::mod());if (cbrt_x_ == -1) {std::cout << -1 << std::endl;exit(0);}const mint cbrt_x = cbrt_x_, inv_cbrt_x = cbrt_x.inv();auto t = [&](long long k) {return k * k * k;};std::vector<mint> f(n + 1), g(n + 1);for (int i = 0; i <= n; ++i) {const mint pow_inv_x = inv_cbrt_x.pow(t(i));f[i] = b[i] * pow_inv_x;g[i] = c[i] * pow_inv_x;}const std::vector<mint> h = atcoder::convolution(f, g);mint ans = 0;for (int i = 0; i <= n; ++i) {const mint pow_x = cbrt_x.pow(t(i));ans += a[i] * pow_x * h[i];}return ans;}int main() {std::ios::sync_with_stdio(false);std::cin.tie(nullptr);int n, x;std::cin >> n >> x;std::vector<mint> a(n + 1), b(n + 1), c(n + 1);for (int i = 0, v; i <= n; ++i) std::cin >> v, a[i] = v;for (int i = 0, v; i <= n; ++i) std::cin >> v, b[i] = v;for (int i = 0, v; i <= n; ++i) std::cin >> v, c[i] = v;std::cout << solve(n, x, a, b, c).val() << std::endl;return 0;}