結果
| 問題 | No.1068 #いろいろな色 / Red and Blue and more various colors (Hard) |
| コンテスト | |
| ユーザー | KoD |
| 提出日時 | 2020-05-29 22:09:36 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 510 ms / 3,500 ms |
| コード長 | 17,078 bytes |
| コンパイル時間 | 3,312 ms |
| コンパイル使用メモリ | 208,932 KB |
| 最終ジャッジ日時 | 2025-01-10 17:28:11 |
| ジャッジサーバーID (参考情報) | judge4 / judge1 |
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 29 |
ソースコード
#include <bits/stdc++.h>
using namespace std;
/**
 * @brief Fast buffered input/output.
 * @author Ebichan (rsk0315)
 * @see https://qiita.com/rsk0315_h4x/items/17a9cb12e0de5fd918f4
 *
 * The word-at-a-time integer parser and the 4-digit-table printer treat the
 * first byte of a word as its lowest byte, i.e. they assume a little-endian
 * target. Integer tokens are expected to be separated by whitespace.
 */
namespace fast {
// Capacity of each I/O buffer in bytes.
static constexpr std::size_t buf_size = 1 << 17;
// Slack bytes after each buffer. scan_parallel() copies 8 bytes at a time and
// can touch a few bytes past `endpos`; the readers also store a NUL at
// `endpos`. Eight bytes keeps all of these accesses inside the array.
static constexpr std::size_t margin = 8;
static char inbuf[buf_size + margin] = {};
static __attribute__((aligned(8))) char outbuf[buf_size + margin] = {};
// Maximum decimal length of a 64-bit unsigned value (18446744073709551615).
static constexpr std::size_t int_digits = 20;
// Bits 4 and 5 of every byte; both are set in each ASCII digit '0'..'9'.
static constexpr std::uintmax_t digit_mask = 0x3030303030303030;
// Lane masks used while folding adjacent digits into 2-, 4- and 8-digit values.
static constexpr std::uintmax_t first_mask = 0x00FF00FF00FF00FF;
static constexpr std::uintmax_t second_mask = 0x0000FFFF0000FFFF;
static constexpr std::uintmax_t third_mask = 0x00000000FFFFFFFF;
// tenpow[i] == 10^i for 0 <= i <= 19.
static constexpr std::uintmax_t tenpow[] = {
    1,
    10,
    100,
    1000,
    10000,
    100000,
    1000000,
    10000000,
    100000000,
    1000000000,
    10000000000,
    100000000000,
    1000000000000,
    10000000000000,
    100000000000000,
    1000000000000000,
    10000000000000000,
    100000000000000000,
    1000000000000000000,
    10000000000000000000u,
};
// Zero-padded decimal text of 0000..9999, 4 bytes per number; filled by
// printer::M_precompute().
static __attribute__((aligned(8))) char inttab[40000] = {};
// Separator used by variadic print() and terminator used by println().
static char S_sep = ' ', S_end = '\n';
template <typename Tp>
using enable_if_integral = std::enable_if<std::is_integral<Tp>::value, Tp>;

class scanner {
  char *pos = inbuf;
  char *endpos = inbuf + buf_size;

  // Initial fill of the input buffer. The byte at `endpos` is set to NUL so
  // the parsers stop at the end of the data.
  void M_read_from_stdin() {
    endpos = inbuf + std::fread(pos, 1, buf_size, stdin);
    *endpos = 0;
  }
  // Move the unread bytes [pos, endpos) to the front of the buffer, append as
  // much fresh input as fits, NUL-terminate, and rewind `pos` to the front.
  // When stdin is exhausted nothing is appended.
  void M_reread_from_stdin() {
    if (pos > endpos) pos = endpos;  // never treat a negative tail as data
    std::size_t len = static_cast<std::size_t>(endpos - pos);
    std::memmove(inbuf, pos, len);  // source and destination may overlap
    char *tmp = inbuf + len;
    endpos = tmp + std::fread(tmp, 1, buf_size - len, stdin);
    *endpos = 0;
    pos = inbuf;
  }

 public:
  scanner() { M_read_from_stdin(); }

  // Parse one integer, 8 input bytes per step, then skip the single
  // delimiter that follows it.
  template <typename Integral,
            typename enable_if_integral<Integral>::type * = nullptr>
  void scan_parallel(Integral &x) {
    if (__builtin_expect(endpos <= pos + int_digits, 0)) M_reread_from_stdin();
    bool ends = false;
    typename std::make_unsigned<Integral>::type y = 0;
    bool neg = false;
    if (std::is_signed<Integral>::value && *pos == '-') {
      neg = true;
      ++pos;
    }
    do {
      // memcpy instead of pointer punning: no strict-aliasing violation and
      // no dependence on sizeof(long).
      std::uint64_t c;
      std::memcpy(&c, pos, sizeof c);
      std::uint64_t d = (c & digit_mask) ^ digit_mask;
      int skip = 8;
      int shift = 8;
      if (d) {
        // Lowest non-digit byte found: keep only the digits before it.
        int ctz = __builtin_ctzll(d);
        if (ctz == 4) break;
        c &= (std::uint64_t(1) << (ctz - 5)) - 1;
        int discarded = (68 - ctz) / 8;
        shift -= discarded;
        c <<= discarded * 8;
        skip -= discarded;
        ends = true;
      }
      c |= digit_mask;
      c ^= digit_mask;
      c = ((c >> 8) + c * 10) & first_mask;
      c = ((c >> 16) + c * 100) & second_mask;
      c = ((c >> 32) + c * 10000) & third_mask;
      y = y * tenpow[shift] + c;
      pos += skip;
    } while (!ends);
    x = (neg ? -y : y);
    ++pos;
  }

  // Parse one integer byte by byte, then skip the single delimiter after it.
  template <typename Integral,
            typename enable_if_integral<Integral>::type * = nullptr>
  void scan_serial(Integral &x) {
    if (__builtin_expect(endpos <= pos + int_digits, 0)) M_reread_from_stdin();
    bool neg = false;
    if (std::is_signed<Integral>::value && *pos == '-') {
      neg = true;
      ++pos;
    }
    typename std::make_unsigned<Integral>::type y = *pos - '0';
    while (*++pos >= '0') y = 10 * y + *pos - '0';
    x = (neg ? -y : y);
    ++pos;
  }

  // Use scan_parallel(x) only when x may be large (about 10^12 or more).
  // Otherwise, even when x <= 10^9, scan_serial(x) may be faster.
  template <typename Integral,
            typename enable_if_integral<Integral>::type * = nullptr>
  void scan(Integral &x) {
    scan_parallel(x);
  }

  // Read bytes up to the first byte <= ' ' into s and consume that byte.
  // A token that runs to the end of the buffer is continued after refilling;
  // at end of input whatever was read is returned instead of looping forever.
  void scan_serial(std::string &s) {
    s.clear();
    while (true) {
      char *startpos = pos;
      while (*pos > ' ') ++pos;
      s.append(startpos, pos);
      if (pos != endpos) {
        ++pos;  // consume the delimiter
        return;
      }
      M_reread_from_stdin();
      if (pos == endpos) return;  // EOF: no new bytes arrived
    }
  }
  void scan(std::string &s) { scan_serial(s); }

  template <typename Tp, typename... Args>
  void scan(Tp &x, Args &&... xs) {
    scan(x);
    scan(std::forward<Args>(xs)...);
  }
};

class printer {
  char *pos = outbuf;

  // Write everything buffered so far to stdout and rewind.
  void M_flush_stdout() {
    std::fwrite(outbuf, 1, pos - outbuf, stdout);
    pos = outbuf;
  }
  // Append n bytes from s, flushing whenever the buffer becomes full.
  void M_write(char const *s, std::size_t n) {
    while (n > 0) {
      std::size_t room = static_cast<std::size_t>(outbuf + buf_size - pos);
      std::size_t k = n < room ? n : room;
      std::memcpy(pos, s, k);
      pos += k;
      s += k;
      n -= k;
      if (pos == outbuf + buf_size) M_flush_stdout();
    }
  }
  // Number of decimal digits of n (n >= 1), via a balanced comparison tree.
  static int S_int_digits(std::uintmax_t n) {
    if (n < tenpow[16]) {
      if (n < tenpow[8]) {
        if (n < tenpow[4]) {
          if (n < tenpow[2]) {
            if (n < tenpow[1]) return 1;
            return 2;
          }
          if (n < tenpow[3]) return 3;
          return 4;
        }
        if (n < tenpow[6]) {
          if (n < tenpow[5]) return 5;
          return 6;
        }
        if (n < tenpow[7]) return 7;
        return 8;
      }
      if (n < tenpow[12]) {
        if (n < tenpow[10]) {
          if (n < tenpow[9]) return 9;
          return 10;
        }
        if (n < tenpow[11]) return 11;
        return 12;
      }
      if (n < tenpow[14]) {
        if (n < tenpow[13]) return 13;
        return 14;
      }
      if (n < tenpow[15]) return 15;
      return 16;
    }
    if (n < tenpow[18]) {
      if (n < tenpow[17]) return 17;
      return 18;
    }
    // 64-bit values >= 10^19 have 20 digits; returning 19 for them would
    // drop the leading digit when printing.
    if (n < tenpow[19]) return 19;
    return 20;
  }
  // Fill inttab with "0000".."9999". Each 8-byte store writes two consecutive
  // numbers (starting with "0000" "0001"); the digitN constants add the
  // per-digit steps and carries as wrap-around 64-bit additions.
  void M_precompute() {
    const std::uint64_t digit1 = 0x0200000002000000;
    const std::uint64_t digit2 = 0xf600fffff6010000;
    const std::uint64_t digit3 = 0xfff600fffff60100;
    const std::uint64_t digit4 = 0xfffff600fffff601;
    std::uint64_t counter = 0x3130303030303030;
    for (int i = 0, i4 = 0; i4 < 10; ++i4, counter += digit4)
      for (int i3 = 0; i3 < 10; ++i3, counter += digit3)
        for (int i2 = 0; i2 < 10; ++i2, counter += digit2)
          for (int i1 = 0; i1 < 5; ++i1, ++i, counter += digit1)
            std::memcpy(inttab + 8 * i, &counter, sizeof counter);
  }

 public:
  printer() { M_precompute(); }
  ~printer() { M_flush_stdout(); }

  void print(char c) {
    if (__builtin_expect(pos + 1 >= outbuf + buf_size, 0)) M_flush_stdout();
    *pos++ = c;
  }
  // Writes N - 1 bytes (the array minus its terminating NUL).
  template <std::size_t N>
  void print(char const (&s)[N]) {
    if (__builtin_expect(pos + N >= outbuf + buf_size, 0)) M_flush_stdout();
    std::memcpy(pos, s, N - 1);
    pos += N - 1;
  }
  void print(char const *s) { M_write(s, std::strlen(s)); }
  // Writes all s.size() bytes, including any embedded NULs.
  void print(std::string const &s) { M_write(s.data(), s.size()); }

  // Decimal representation of x; a leading '-' for negative values.
  template <typename Integral,
            typename enable_if_integral<Integral>::type * = nullptr>
  void print(Integral x) {
    if (__builtin_expect(pos + int_digits >= outbuf + buf_size, 0))
      M_flush_stdout();
    if (x == 0) {
      *pos++ = '0';
      return;
    }
    if (x < 0) {
      *pos++ = '-';
      // -min() overflows, so the magnitude of the minimum is spelled out.
      if (__builtin_expect(x == std::numeric_limits<Integral>::min(), 0)) {
        switch (sizeof x) {
          case 2:
            print("32768");
            return;
          case 4:
            print("2147483648");
            return;
          case 8:
            print("9223372036854775808");
            return;
        }
      }
      x = -x;
    }
    typename std::make_unsigned<Integral>::type y = x;
    int len = S_int_digits(y);
    pos += len;
    char *tmp = pos;
    // Emit single digits until tmp is 4-byte aligned, then 4 digits at a
    // time from inttab, then the remaining leading digits one by one.
    int w = (pos - outbuf) & 3;
    if (w > len) w = len;
    for (int i = w; i > 0; --i) {
      *--tmp = y % 10 + '0';
      y /= 10;
    }
    len -= w;
    while (len >= 4) {
      tmp -= 4;
      std::memcpy(tmp, inttab + 4 * (y % 10000), 4);
      len -= 4;
      if (len) y /= 10000;
    }
    while (len-- > 0) {
      *--tmp = y % 10 + '0';
      y /= 10;
    }
  }

  // Print the arguments separated by S_sep.
  template <typename Tp, typename... Args>
  void print(Tp const &x, Args &&... xs) {
    if (sizeof...(Args) > 0) {
      print(x);
      print(S_sep);
      print(std::forward<Args>(xs)...);
    }
  }
  template <typename Tp>
  void println(Tp const &x) {
    print(x), print(S_end);
  }
  template <typename Tp, typename... Args>
  void println(Tp const &x, Args &&... xs) {
    print(x, std::forward<Args>(xs)...);
    print(S_end);
  }
  static void set_sep(char c) { S_sep = c; }
  static void set_end(char c) { S_end = c; }
};
}  // namespace fast
// Program-wide I/O objects. Constructing the scanner fills its buffer from
// stdin; the printer writes its remaining buffer to stdout when destroyed.
fast::scanner fastin{};
fast::printer fastout{};
/**
 * Returns mod^{-1} modulo 2^32, the constant used by Montgomery reduction.
 *
 * Uses Newton's iteration x <- x * (2 - mod * x) in 32-bit arithmetic. For odd
 * mod, x0 = mod already satisfies mod * x0 == 1 (mod 2^3), and each step
 * doubles the number of correct low bits (3 -> 6 -> 12 -> 24 -> 48), so four
 * steps suffice for every odd modulus. The previous exponentiation-based loop
 * was only exact when mod - 2 had at least 30 bits.
 *
 * @param mod odd modulus; for even mod no inverse exists and the result is
 *            meaningless (callers should static_assert r * mod == 1).
 */
static constexpr uint32_t get_r(int mod) {
  const uint32_t m = static_cast<uint32_t>(mod);
  uint32_t ret = m;
  for (int i = 0; i < 4; ++i) ret *= 2 - m * ret;
  return ret;
}
/**
 * Modular integer modulo the odd compile-time constant `mod` (< 2^30), stored
 * in Montgomery form with R = 2^32. The stored word may be left partially
 * reduced (additions keep it below 2 * mod); get() returns the canonical
 * residue in [0, mod).
 */
template <uint32_t mod>
struct LazyMontgomeryModInt {
  using mint = LazyMontgomeryModInt;
  using i32 = int32_t;
  using u32 = uint32_t;
  using u64 = uint64_t;

  // mod^{-1} modulo 2^32 (verified by the first static_assert).
  static constexpr u32 r = get_r(mod);
  // 2^64 mod `mod`, i.e. R^2 mod `mod`; converts plain values into
  // Montgomery form through reduce().
  static constexpr u32 n2 = -u64(mod) % mod;
  static_assert(r * mod == 1, "invalid, r * mod != 1");
  static_assert(mod < (1 << 30), "invalid, mod >= 2 ^ 30");
  static_assert((mod & 1) == 1, "invalid, mod % 2 == 0");

  // Montgomery representation of the value.
  u32 a;

  LazyMontgomeryModInt() : a(0) {}
  // Accepts any signed 64-bit value, including negatives.
  LazyMontgomeryModInt(const int64_t &b)
      : a(reduce(u64(b % mod + mod) * n2)) {}

  // Montgomery reduction: a 32-bit value congruent to x * 2^{-32} mod `mod`.
  static u32 reduce(const u64 &x) {
    const u32 correction = u32((u64(u32(x) * r) * mod) >> 32);
    return u32(x >> 32) + mod - correction;
  }

  mint &operator+=(const mint &rhs) {
    a += rhs.a - 2 * mod;
    if (static_cast<i32>(a) < 0) a += 2 * mod;
    return *this;
  }
  mint &operator-=(const mint &rhs) {
    a -= rhs.a;
    if (static_cast<i32>(a) < 0) a += 2 * mod;
    return *this;
  }
  mint &operator*=(const mint &rhs) {
    a = reduce(u64(a) * rhs.a);
    return *this;
  }
  mint &operator/=(const mint &rhs) { return *this *= rhs.inverse(); }

  mint operator+(const mint &rhs) const {
    mint out = *this;
    out += rhs;
    return out;
  }
  mint operator-(const mint &rhs) const {
    mint out = *this;
    out -= rhs;
    return out;
  }
  mint operator*(const mint &rhs) const {
    mint out = *this;
    out *= rhs;
    return out;
  }
  mint operator/(const mint &rhs) const {
    mint out = *this;
    out /= rhs;
    return out;
  }

  // Canonical residue in [0, mod).
  u32 get() const {
    const u32 v = reduce(a);
    return v < mod ? v : v - mod;
  }
  // this^n by binary exponentiation.
  mint pow(u64 n) const {
    mint result(1);
    mint base = *this;
    for (; n; n >>= 1) {
      if (n & 1) result *= base;
      base *= base;
    }
    return result;
  }
  // Multiplicative inverse via Fermat's little theorem (mod assumed prime).
  mint inverse() const { return pow(mod - 2); }
  static constexpr u32 get_mod() { return mod; }

  friend std::ostream &operator<<(std::ostream &os, const mint &b) {
    return os << b.get();
  }
  friend std::istream &operator>>(std::istream &is, mint &b) {
    int64_t t;
    is >> t;
    b = LazyMontgomeryModInt<mod>(t);
    return is;
  }
};
/**
 * Returns the smallest g >= 2 such that g^((mod - 1) / p) != 1 (mod `mod`)
 * for every prime p dividing mod - 1, i.e. the smallest primitive root when
 * `mod` is prime.
 */
static constexpr uint32_t get_pr(uint32_t mod) {
  using u64 = uint64_t;
  // Distinct prime factors of mod - 1, by trial division.
  u64 factors[32] = {};
  int count = 0;
  u64 rest = mod - 1;
  for (u64 p = 2; p * p <= rest; ++p) {
    if (rest % p != 0) continue;
    factors[count++] = p;
    do {
      rest /= p;
    } while (rest % p == 0);
  }
  if (rest != 1) factors[count++] = rest;

  for (uint32_t g = 2;; ++g) {
    bool primitive = true;
    for (int i = 0; i < count && primitive; ++i) {
      // acc = g^((mod - 1) / factors[i]) mod `mod`
      u64 base = g, e = (mod - 1) / factors[i], acc = 1;
      for (; e; e >>= 1) {
        if (e & 1) acc = acc * base % mod;
        base = base * base % mod;
      }
      primitive = (acc != 1);
    }
    if (primitive) return g;
  }
}
/**
 * Radix-4 number-theoretic transform over the modular type `mint`, which must
 * provide get_mod(), pow() and inverse() (e.g. LazyMontgomeryModInt).
 * multiply() returns the cyclic-free convolution of two coefficient vectors.
 */
template <typename mint>
struct NTT {
  static constexpr uint32_t mod = mint::get_mod();
  // Generator found by get_pr(); its powers give the roots of unity.
  static constexpr uint32_t pr = get_pr(mod);
  // Exponent of 2 in mod - 1; transform lengths are limited to 2^level.
  static constexpr int level = __builtin_ctzll(mod - 1);
  // Twiddle-factor step tables for the forward (dw) and inverse (dy)
  // transforms; refilled by setwy() for each transform size.
  mint dw[level], dy[level];

  // Prepare dw/dy for transforms of length 2^k. w[i] and y[i] are a primitive
  // 2^(i+1)-th root of unity and its inverse.
  void setwy(int k) {
    mint w[level], y[level];
    w[k - 1] = mint(pr).pow((mod - 1) / (1 << k));
    y[k - 1] = w[k - 1].inverse();
    for (int i = k - 2; i > 0; --i)
      w[i] = w[i + 1] * w[i + 1], y[i] = y[i + 1] * y[i + 1];
    dw[1] = w[1], dy[1] = y[1], dw[2] = w[2], dy[2] = y[2];
    for (int i = 3; i < k; ++i) {
      dw[i] = dw[i - 1] * y[i - 2] * w[i];
      dy[i] = dy[i - 1] * w[i - 2] * y[i];
    }
  }

  // In-place forward transform of `a` (length 2^k). No reordering pass is
  // done; the output is in the order ifft4() consumes.
  void fft4(std::vector<mint> &a, int k) {
    // Odd k: one radix-2 stage first so the rest is pure radix-4.
    if (k & 1) {
      int v = 1 << (k - 1);
      for (int j = 0; j < v; ++j) {
        mint ajv = a[j + v];
        a[j + v] = a[j] - ajv;
        a[j] += ajv;
      }
    }
    int u = 1 << (2 + (k & 1));
    int v = 1 << (k - 2 - (k & 1));
    mint one = mint(1);
    mint imag = dw[1];
    while (v) {
      // jh = 0
      {
        int j0 = 0;
        int j1 = v;
        int j2 = j1 + v;
        int j3 = j2 + v;
        for (; j0 < v; ++j0, ++j1, ++j2, ++j3) {
          mint t0 = a[j0], t1 = a[j1], t2 = a[j2], t3 = a[j3];
          mint t0p2 = t0 + t2, t1p3 = t1 + t3;
          mint t0m2 = t0 - t2, t1m3 = (t1 - t3) * imag;
          a[j0] = t0p2 + t1p3, a[j1] = t0p2 - t1p3;
          a[j2] = t0m2 + t1m3, a[j3] = t0m2 - t1m3;
        }
      }
      // jh >= 1
      mint ww = one, xx = one * dw[2], wx = one;
      for (int jh = 4; jh < u;) {
        ww = xx * xx, wx = ww * xx;
        int j0 = jh * v;
        int je = j0 + v;
        int j2 = je + v;
        for (; j0 < je; ++j0, ++j2) {
          mint t0 = a[j0], t1 = a[j0 + v] * xx, t2 = a[j2] * ww,
               t3 = a[j2 + v] * wx;
          mint t0p2 = t0 + t2, t1p3 = t1 + t3;
          mint t0m2 = t0 - t2, t1m3 = (t1 - t3) * imag;
          a[j0] = t0p2 + t1p3, a[j0 + v] = t0p2 - t1p3;
          a[j2] = t0m2 + t1m3, a[j2 + v] = t0m2 - t1m3;
        }
        xx *= dw[__builtin_ctzll((jh += 4))];
      }
      u <<= 2;
      v >>= 2;
    }
  }

  // In-place inverse transform of `a` (length 2^k) for output of fft4().
  // The result is NOT divided by 2^k; multiply() does that scaling.
  void ifft4(std::vector<mint> &a, int k) {
    int u = 1 << (k - 2);
    int v = 1;
    mint one = mint(1);
    mint imag = dy[1];
    while (u) {
      // jh = 0
      {
        int j0 = 0;
        int j1 = v;
        int j2 = v + v;
        int j3 = j2 + v;
        for (; j0 < v; ++j0, ++j1, ++j2, ++j3) {
          mint t0 = a[j0], t1 = a[j1], t2 = a[j2], t3 = a[j3];
          mint t0p1 = t0 + t1, t2p3 = t2 + t3;
          mint t0m1 = t0 - t1, t2m3 = (t2 - t3) * imag;
          a[j0] = t0p1 + t2p3, a[j2] = t0p1 - t2p3;
          a[j1] = t0m1 + t2m3, a[j3] = t0m1 - t2m3;
        }
      }
      // jh >= 1
      mint ww = one, xx = one * dy[2], yy = one;
      u <<= 2;
      for (int jh = 4; jh < u;) {
        ww = xx * xx, yy = xx * imag;
        int j0 = jh * v;
        int je = j0 + v;
        int j2 = je + v;
        for (; j0 < je; ++j0, ++j2) {
          mint t0 = a[j0], t1 = a[j0 + v], t2 = a[j2], t3 = a[j2 + v];
          mint t0p1 = t0 + t1, t2p3 = t2 + t3;
          mint t0m1 = (t0 - t1) * xx, t2m3 = (t2 - t3) * yy;
          a[j0] = t0p1 + t2p3, a[j2] = (t0p1 - t2p3) * ww;
          a[j0 + v] = t0m1 + t2m3, a[j2 + v] = (t0m1 - t2m3) * ww;
        }
        xx *= dy[__builtin_ctzll(jh += 4)];
      }
      u >>= 4;
      v <<= 2;
    }
    // Odd k: final radix-2 stage mirroring the first stage of fft4().
    if (k & 1) {
      u = 1 << (k - 1);
      for (int j = 0; j < u; ++j) {
        mint ajv = a[j] - a[j + u];
        a[j] += a[j + u];
        a[j + u] = ajv;
      }
    }
  }

  /**
   * Returns c with c[i] = sum_{j} a[j] * b[i - j], of length
   * a.size() + b.size() - 1. If either input is empty (the zero polynomial)
   * the result is empty.
   * NOTE(review): the padded length must not exceed 2^level; not checked here.
   */
  std::vector<mint> multiply(const std::vector<mint> &a,
                             const std::vector<mint> &b) {
    // Without this guard l would be -1 (both empty, then resize(size_t(-1))
    // throws) or a spurious vector of zeros (one empty).
    if (a.empty() || b.empty()) return {};
    int l = a.size() + b.size() - 1;
    // Smallest power of two M = 2^k >= l, with k >= 2 for the radix-4 code.
    int k = 2, M = 4;
    while (M < l) M <<= 1, ++k;
    setwy(k);
    std::vector<mint> s(M), t(M);
    for (int i = 0; i < (int)a.size(); ++i) s[i] = a[i];
    for (int i = 0; i < (int)b.size(); ++i) t[i] = b[i];
    fft4(s, k);
    fft4(t, k);
    for (int i = 0; i < M; ++i) s[i] *= t[i];
    ifft4(s, k);
    s.resize(l);
    mint invm = mint(M).inverse();
    for (int i = 0; i < l; ++i) s[i] *= invm;
    return s;
  }
};
/**
 * Reads N and Q, then N values x_i, then Q indices. Builds the polynomial
 * prod_i ((x_i - 1) + t) modulo 998244353 and, for each index, prints the
 * coefficient of t^index (0 for indices outside [0, degree]).
 */
int main() {
  constexpr uint32_t MOD = 998244353;
  using mint = LazyMontgomeryModInt<MOD>;
  NTT<mint> ntt;

  int N, Q;
  fastin.scan_serial(N);
  fastin.scan_serial(Q);

  // Bottom-up segment tree: leaf i holds [(x_i - 1), 1] and each inner node
  // the product of its children, so vec[1] ends up as the full product.
  int size = 1;
  while (size < N) size <<= 1;
  std::vector<std::vector<mint>> vec(size << 1);
  for (int i = 0; i < N; ++i) {
    uint64_t x;
    fastin.scan_parallel(x);
    auto &v = vec[size + i];
    v.resize(2);
    // (x - 1) mod MOD without the unsigned wrap-around x - 1 has for x == 0.
    v[0].a = mint::reduce(((x % MOD + MOD - 1) % MOD) * mint::n2);
    v[1].a = mint::reduce(uint64_t(1) * mint::n2);
  }
  for (int i = size - 1; i > 0; --i) {
    auto &l = vec[i << 1 | 0];
    auto &r = vec[i << 1 | 1];
    // Children are never read again, so the one-sided cases move, not copy.
    if (l.empty()) {
      vec[i] = std::move(r);
    } else if (r.empty()) {
      vec[i] = std::move(l);
    } else {
      vec[i] = ntt.multiply(l, r);
    }
  }
  // With no factors (N == 0) the product is the constant polynomial 1.
  if (vec[1].empty()) vec[1].assign(1, mint(1));

  const auto &poly = vec[1];
  while (Q--) {
    int x;
    fastin.scan_serial(x);
    const uint32_t ans =
        (0 <= x && x < int(poly.size())) ? poly[x].get() : uint32_t(0);
    fastout.println(ans);
  }
  return 0;
}
KoD