結果
| 問題 | No.2556 Increasing Matrix |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2023-12-08 22:12:58 |
| 言語 | C++14 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 5,258 ms / 6,000 ms |
| コード長 | 9,771 bytes |
| コンパイル時間 | 2,451 ms |
| コンパイル使用メモリ | 191,432 KB |
| 実行使用メモリ | 32,744 KB |
| 最終ジャッジ日時 | 2024-09-27 03:13:35 |
| 合計ジャッジ時間 | 27,559 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 22 |
ソースコード
#include <bits/stdc++.h>
using namespace std;
#define int long long
// Prime modulus used by the solver.
const int p=998244353;
// Modular exponentiation: returns a^b mod p as a value in [0, p).
// The base is reduced into [0, p) first, so unreduced or negative bases are
// handled and every intermediate product stays below p^2 (no overflow).
// b is expected to be non-negative; b <= 0 yields 1.
int po(int a,int b) {
    long long cur = a % p;
    if (cur < 0) cur += p;
    long long result = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) result = result * cur % p;
        cur = cur * cur % p;
    }
    return result;
}
// Modular inverse by Fermat's little theorem (p is prime): x^(p-2) mod p.
// Returns 0 when x is a multiple of p, which has no inverse.
int inv1(int x) {return po(x,p-2);}
#include <bits/stdc++.h>
using namespace std;
// Prime modulus of the polynomial toolkit below (998244353 = 119 * 2^23 + 1,
// same value as p), so power-of-two NTT lengths up to 2^23 are available.
const int md = 998244353;
// Modular-arithmetic and polynomial toolkit over Z/md.
// Polynomials are vector<int> coefficient lists, constant term first
// (index i holds the coefficient of X^i).
namespace faq {
// x = x + y (mod md), in place. For x, y in [0, md) the result is in [0, md);
// the final %= keeps other inputs congruent with |x| < md (possibly negative).
inline void add(int &x, int y) {
    x += y;
    if (x >= md) {
        x -= md;
    }
    x %= md;
}
// x = x - y (mod md), in place; same range behaviour as add().
inline void sub(int &x, int y) {
    x -= y;
    if (x < 0) {
        x += md;
    }
    x %= md;
}
// x * y mod md, multiplying in long long.
inline int mul(int x, int y) {
    return (long long) x * y % md;
}
// x^y mod md by binary exponentiation; y must be non-negative
// (a negative y never reaches 0 under an arithmetic >>= 1).
inline int power(int x, int y) {
    int res = 1;
    for (; y; y >>= 1, x = mul(x, x)) {
        if (y & 1) {
            res = mul(res, x);
        }
    }
    return res;
}
// Modular inverse of a via the extended Euclidean algorithm. a is reduced into
// [0, md) first and the returned value is non-negative. Returns 0 when a is a
// multiple of md (the loop body never runs).
inline int inv(int a) {
    a %= md;
    if (a < 0) {
        a += md;
    }
    int b = md, u = 0, v = 1;
    while (a) {
        int t = b / a;
        b -= t * a;
        swap(a, b);
        u -= t * v;
        swap(u, v);
    }
    if (u < 0) {
        u += md;
    }
    return u;
}
// Radix-2 number-theoretic transform modulo md. Shared state:
//   base     - log2 of the transform size rev/roots currently support;
//   root     - element of multiplicative order exactly 2^max_base (set by init());
//   max_base - exponent of 2 in md - 1, or -1 before init() has run;
//   rev      - bit-reversal permutation on `base` bits;
//   roots    - for i a power of two and 0 <= k < i, roots[i + k] = w^k where w
//              is a primitive (2i)-th root of unity (the butterfly twiddles).
namespace ntt {
int base = 1, root = -1, max_base = -1;
vector<int> rev = {0, 1}, roots = {0, 1};
// Computes max_base and finds the smallest root >= 2 of order exactly 2^max_base.
void init() {
    int temp = md - 1;
    max_base = 0;
    while (temp % 2 == 0) {
        temp >>= 1;
        ++max_base;
    }
    root = 2;
    while (true) {
        if (power(root, 1 << max_base) == 1 && power(root, 1 << (max_base - 1)) != 1) {
            break;
        }
        ++root;
    }
}
// Extends rev/roots so transforms of length up to 2^nbase are supported,
// running init() on first use. Asserts nbase <= max_base.
void ensure_base(int nbase) {
    if (max_base == -1) {
        init();
    }
    if (nbase <= base) {
        return;
    }
    assert(nbase <= max_base);
    // rev[i] = i with its low nbase bits reversed.
    rev.resize(1 << nbase);
    for (int i = 0; i < (1 << nbase); ++i) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (nbase - 1));
    }
    // Add one twiddle level at a time: z has order 2^(base + 1), and
    // roots[2i], roots[2i + 1] receive the even and odd powers of z.
    roots.resize(1 << nbase);
    while (base < nbase) {
        int z = power(root, 1 << (max_base - 1 - base));
        for (int i = 1 << (base - 1); i < (1 << base); ++i) {
            roots[i << 1] = roots[i];
            roots[i << 1 | 1] = mul(roots[i], z);
        }
        ++base;
    }
}
// In-place forward NTT of a; a.size() must be a power of two (its log2 is read
// with __builtin_ctz). Bit-reversal permutation, then radix-2 butterflies with
// every value reduced by % md.
void dft(vector<int> &a) {
    int n = a.size(), zeros = __builtin_ctz(n);
    ensure_base(zeros);
    // rev is built for `base` bits; shifting by base - zeros gives the
    // reversal on `zeros` bits.
    int shift = base - zeros;
    for (int i = 0; i < n; ++i) {
        if (i < rev[i] >> shift) {
            swap(a[i], a[rev[i] >> shift]);
        }
    }
    for (int i = 1; i < n; i <<= 1) {
        for (int j = 0; j < n; j += i << 1) {
            for (int k = 0; k < i; ++k) {
                int x = a[j + k], y = mul(a[j + k + i], roots[i + k]);
                a[j + k] = (x + y) % md;
                a[j + k + i] = (x + md - y) % md;
            }
        }
    }
}
// Product of a and b via NTT; the result has a.size() + b.size() - 1
// coefficients. a and b are taken by value and reused as work buffers.
vector<int> multiply(vector<int> a, vector<int> b) {
    // Both empty: the size below would wrap to -1 and a.resize(need) would fail.
    if (a.empty() && b.empty()) {
        return vector<int>();
    }
    int need = a.size() + b.size() - 1, nbase = 0;
    while ((1 << nbase) < need) {
        ++nbase;
    }
    ensure_base(nbase);
    int sz = 1 << nbase;
    a.resize(sz);
    b.resize(sz);
    // Squaring shortcut: identical padded inputs are transformed only once.
    bool equal = a == b;
    dft(a);
    if (equal) {
        b = a;
    } else {
        dft(b);
    }
    int inv_sz = inv(sz);
    for (int i = 0; i < sz; ++i) {
        a[i] = mul(mul(a[i], b[i]), inv_sz);
    }
    // Inverse transform: after scaling by 1/sz, reversing a[1..] and applying
    // the forward transform yields the coefficients.
    reverse(a.begin() + 1, a.end());
    dft(a);
    a.resize(need);
    return a;
}
// Power-series inverse: b with a * b == 1 (mod X^n), n = a.size(). Recursively
// inverts the first ceil(n / 2) coefficients, then applies one Newton step
// b <- b * (2 - a * b) in the transform domain. a[0] must be invertible mod md.
vector<int> inverse(vector<int> a) {
    int n = a.size(), m = (n + 1) >> 1;
    // Empty input: m would be 0 and the recursion below would never end.
    if (n == 0) {
        return vector<int>();
    }
    if (n == 1) {
        return vector<int>(1, inv(a[0]));
    } else {
        vector<int> b = inverse(vector<int>(a.begin(), a.begin() + m));
        int need = n << 1, nbase = 0;
        while ((1 << nbase) < need) {
            ++nbase;
        }
        ensure_base(nbase);
        int sz = 1 << nbase;
        a.resize(sz);
        b.resize(sz);
        dft(a);
        dft(b);
        int inv_sz = inv(sz);
        for (int i = 0; i < sz; ++i) {
            a[i] = mul(mul(md + 2 - mul(a[i], b[i]), b[i]), inv_sz);
        }
        reverse(a.begin() + 1, a.end());
        dft(a);
        a.resize(n);
        return a;
    }
}
}  // namespace ntt
// The operators below use the NTT routines through these names.
using ntt::multiply;
using ntt::inverse;
// Coefficient-wise a += b (mod md); a grows to b.size() if shorter.
vector<int>& operator += (vector<int> &a, const vector<int> &b) {
    if (a.size() < b.size()) {
        a.resize(b.size());
    }
    for (int i = 0; i < b.size(); ++i) {
        add(a[i], b[i]);
    }
    return a;
}
vector<int> operator + (const vector<int> &a, const vector<int> &b) {
    vector<int> c = a;
    return c += b;
}
// Coefficient-wise a -= b (mod md); a grows to b.size() if shorter.
vector<int>& operator -= (vector<int> &a, const vector<int> &b) {
    if (a.size() < b.size()) {
        a.resize(b.size());
    }
    for (int i = 0; i < b.size(); ++i) {
        sub(a[i], b[i]);
    }
    return a;
}
vector<int> operator - (const vector<int> &a, const vector<int> &b) {
    vector<int> c = a;
    return c -= b;
}
// Polynomial product a *= b: schoolbook O(|a| * |b|) when the shorter operand
// has fewer than 128 coefficients, NTT multiply() otherwise.
vector<int>& operator *= (vector<int> &a, const vector<int> &b) {
    // Both empty: a.size() + b.size() - 1 would wrap around; keep a empty.
    if (a.empty() && b.empty()) {
        return a;
    }
    if (min(a.size(), b.size()) < 128) {
        vector<int> c = a;
        a.assign(a.size() + b.size() - 1, 0);
        for (int i = 0; i < c.size(); ++i) {
            for (int j = 0; j < b.size(); ++j) {
                add(a[i + j], mul(c[i], b[j]));
            }
        }
    } else {
        a = multiply(a, b);
    }
    return a;
}
vector<int> operator * (const vector<int> &a, const vector<int> &b) {
    vector<int> c = a;
    return c *= b;
}
// Polynomial quotient: a becomes floor(a / b), with n - m + 1 coefficients
// (n = a.size(), m = b.size()), or empty when n < m. Computed on the reversed
// polynomials via a power-series inverse, so b.back() must be invertible.
vector<int>& operator /= (vector<int> &a, const vector<int> &b) {
    int n = a.size(), m = b.size();
    if (n < m) {
        a.clear();
    } else {
        vector<int> c = b;
        reverse(a.begin(), a.end());
        reverse(c.begin(), c.end());
        c.resize(n - m + 1);
        a *= inverse(c);
        a.erase(a.begin() + n - m + 1, a.end());
        reverse(a.begin(), a.end());
    }
    return a;
}
vector<int> operator / (const vector<int> &a, const vector<int> &b) {
    vector<int> c = a;
    return c /= b;
}
// Polynomial remainder: a becomes a - (a / b) * b truncated to m - 1
// coefficients; a is unchanged when a.size() < b.size().
vector<int>& operator %= (vector<int> &a, const vector<int> &b) {
    int n = a.size(), m = b.size();
    if (n >= m) {
        vector<int> c = (a / b) * b;
        a.resize(m - 1);
        for (int i = 0; i < m - 1; ++i) {
            sub(a[i], c[i]);
        }
    }
    return a;
}
vector<int> operator % (const vector<int> &a, const vector<int> &b) {
    vector<int> c = a;
    return c %= b;
}
// Formal derivative; an empty input yields an empty result.
vector<int> derivative(const vector<int> &a) {
    int n = a.size();
    if (n == 0) {
        return vector<int>();
    }
    vector<int> b(n - 1);
    for (int i = 1; i < n; ++i) {
        b[i - 1] = mul(a[i], i);
    }
    return b;
}
// Formal antiderivative with constant term 0 (a.size() + 1 coefficients).
// invs[i] = 1/i mod md via the recurrence 1/i = -(md / i) * 1/(md % i).
vector<int> primitive(const vector<int> &a) {
    int n = a.size();
    vector<int> b(n + 1), invs(n + 1);
    for (int i = 1; i <= n; ++i) {
        invs[i] = i == 1 ? 1 : mul(md - md / i, invs[md % i]);
        b[i] = mul(a[i - 1], invs[i]);
    }
    return b;
}
// ln(a) mod X^{a.size()} as the antiderivative of a' / a. The constant term is
// 0, so this is the logarithm only when a[0] == 1 (assumed of callers).
vector<int> logarithm(const vector<int> &a) {
    vector<int> b = primitive(derivative(a) * inverse(a));
    b.resize(a.size());
    return b;
}
// exp(a) mod X^{a.size()} by Newton iteration b <- b * (1 + a - ln b),
// doubling the precision each round from b = 1 (so a[0] is assumed to be 0).
vector<int> exponent(const vector<int> &a) {
    vector<int> b(1, 1);
    while (b.size() < a.size()) {
        vector<int> c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
        add(c[0], 1);
        vector<int> old_b = b;
        b.resize(b.size() << 1);
        c -= logarithm(b);
        c *= old_b;
        // Only the newly gained upper half of b is updated; the lower half is kept.
        for (int i = b.size() >> 1; i < b.size(); ++i) {
            b[i] = c[i];
        }
    }
    b.resize(a.size());
    return b;
}
// a^m mod X^n (n = a.size()), m >= 0. The lowest nonzero term a[p] X^p is
// factored out and the rest raised to the m-th power as exp(m * ln(.)).
// The local p shadows the file-level constant p.
vector<int> power(const vector<int> &a, int m) {
    int n = a.size(), p = -1;
    vector<int> b(n);
    // Empty input: nothing to write (b[0] would be out of range).
    if (n == 0) {
        return b;
    }
    // a^0 == 1. Handled up front: with a[0] == 0 the general path below would
    // read a[i + p] past the end of a when m == 0.
    if (m == 0) {
        b[0] = 1;
        return b;
    }
    for (int i = 0; i < n; ++i) {
        if (a[i]) {
            p = i;
            break;
        }
    }
    if (p == -1) {
        // Zero polynomial: 0^m == 0 for m != 0.
        return b;
    }
    if ((long long) m * p >= n) {
        return b;
    }
    int mu = power(a[p], m), di = inv(a[p]);
    vector<int> c(n - m * p);
    for (int i = 0; i < n - m * p; ++i) {
        c[i] = mul(a[i + p], di);
    }
    c = logarithm(c);
    for (int i = 0; i < n - m * p; ++i) {
        c[i] = mul(c[i], m);
    }
    c = exponent(c);
    for (int i = 0; i < n - m * p; ++i) {
        b[i + m * p] = mul(c[i], mu);
    }
    return b;
}
// Square root of a mod X^{a.size()} by Newton iteration b <- (b + a / b) / 2
// from b = 1 (so a[0] is assumed to be 1). Only the newly gained upper half of
// b is written, as (a / b)[i] / 2; (md + 1) >> 1 is the inverse of 2.
vector<int> sqrt(const vector<int> &a) {
    vector<int> b(1, 1);
    while (b.size() < a.size()) {
        vector<int> c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
        b.resize(b.size() << 1);
        c *= inverse(b);
        for (int i = b.size() >> 1; i < b.size(); ++i) {
            b[i] = mul(c[i], (md + 1) >> 1);
        }
    }
    b.resize(a.size());
    return b;
}
// Product of the polynomials all[l..r] by balanced divide and conquer;
// an empty vector when l > r.
vector<int> multiply_all(int l, int r, vector<vector<int>> &all) {
    if (l > r) {
        return vector<int>();
    } else if (l == r) {
        return all[l];
    } else {
        int y = (l + r) >> 1;
        return multiply_all(l, y, all) * multiply_all(y + 1, r, all);
    }
}
// Multipoint evaluation: y[i] = f(x[i]) for every i. Subproduct tree: leaves
// up[n..2n-1] are X - x[i], up[i] = up[2i] * up[2i+1], down[i] = f mod up[i],
// and y[i] is the constant term of the leaf remainder. An empty f evaluates to 0.
vector<int> evaluate(const vector<int> &f, const vector<int> &x) {
    int n = x.size();
    if (!n) {
        return vector<int>();
    }
    // Empty f: the remainders would stay empty and down[i + n][0] read out of range.
    if (f.empty()) {
        return vector<int>(n, 0);
    }
    vector<vector<int>> up(n * 2);
    for (int i = 0; i < n; ++i) {
        up[i + n] = vector<int>{(md - x[i]) % md, 1};
    }
    for (int i = n - 1; i; --i) {
        up[i] = up[i << 1] * up[i << 1 | 1];
    }
    vector<vector<int>> down(n * 2);
    down[1] = f % up[1];
    for (int i = 2; i < n * 2; ++i) {
        down[i] = down[i >> 1] % up[i];
    }
    vector<int> y(n);
    for (int i = 0; i < n; ++i) {
        y[i] = down[i + n][0];
    }
    return y;
}
// Lagrange interpolation: the polynomial of degree < n through (x[i], y[i]).
// With P = prod (X - x[i]) at the root of the subproduct tree, each weight is
// y[i] / P'(x[i]) and the weighted partial products are merged bottom-up.
// The x[i] must be distinct mod md: inv() returns 0 when P'(x[i]) == 0.
vector<int> interpolate(const vector<int> &x, const vector<int> &y) {
    int n = x.size();
    // No points: up[1] below would not exist.
    if (n == 0) {
        return vector<int>();
    }
    vector<vector<int>> up(n * 2);
    for (int i = 0; i < n; ++i) {
        up[i + n] = vector<int>{(md - x[i]) % md, 1};
    }
    for (int i = n - 1; i; --i) {
        up[i] = up[i << 1] * up[i << 1 | 1];
    }
    vector<int> a = evaluate(derivative(up[1]), x);
    for (int i = 0; i < n; ++i) {
        a[i] = mul(y[i], inv(a[i]));
    }
    vector<vector<int>> down(n * 2);
    for (int i = 0; i < n; ++i) {
        down[i + n] = vector<int>(1, a[i]);
    }
    for (int i = n - 1; i; --i) {
        down[i] = down[i << 1] * up[i << 1 | 1] + down[i << 1 | 1] * up[i << 1];
    }
    return down[1];
}
}  // namespace faq
using namespace faq;
using namespace faq::ntt;
#define all(x) (x).begin(),(x).end()
// Vandermonde product: returns prod over index pairs i < j of (a[j] - a[i])
// modulo p, as a value in [0, p).
// Divide and conquer: the first ceil(n/2) values form a1 and the rest a2. The
// cross terms prod_{v in a2} prod_{u in a1} (v - u) are the values of
// f(X) = prod_{u in a1} (X - u) at the points of a2 (multipoint evaluation),
// and the pairs inside each half are handled recursively.
int slv(vector<int> a)
{
    // Zero or one value: no pairs, empty product. An empty input would
    // otherwise split into two empty halves and recurse forever.
    if (a.size() <= 1) return 1;
    // Canonical residues in [0, p): keeps (p - u) % p non-negative for any
    // input and hands evaluate() reduced evaluation points.
    for (int &v : a) v = (v % p + p) % p;
    const size_t half = (a.size() + 1) / 2;
    vector<int> a1(a.begin(), a.begin() + half);
    vector<int> a2(a.begin() + half, a.end());
    vector<vector<int>> factors;
    factors.reserve(a1.size());
    for (int u : a1) factors.push_back({(p - u) % p, 1});  // the factor X - u
    vector<int> f = multiply_all(0, (int)factors.size() - 1, factors);
    vector<int> g = evaluate(f, a2);
    long long res = (long long)slv(a1) * slv(a2) % p;
    for (int val : g) res = res * (val % p) % p;
    return (res % p + p) % p;
}
// Reads n and A[0..n-1], sets a[i] = A[i] + i, and prints
//   ( prod_{i<j} (a[j] - a[i]) / prod_{k=1}^{n-1} k! )^2  mod p.
// The denominator prod_{k=1}^{n-1} k! equals prod_{i=1}^{n-1} i^(n-i); it is
// accumulated directly and inverted once instead of raising n - 1 separate
// modular inverses to powers.
int32_t main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    int n;
    if (!(cin >> n)) return 0;
    if (n <= 0) {
        // No values: every product is empty, so the answer is 1.
        cout << 1 << '\n';
        return 0;
    }
    vector<int> a(n);
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
        a[i] += i;
    }
    long long fact = 1;       // k! mod p
    long long superfact = 1;  // 1! * 2! * ... * k! mod p
    for (int k = 1; k <= n - 1; ++k) {
        fact = fact * (k % p) % p;
        superfact = superfact * fact % p;
    }
    long long res = (long long)inv1(superfact) * (slv(a) % p) % p;
    cout << ((res * res) % p + p) % p << '\n';
    return 0;
}