結果
| 問題 |
No.3247 Multiplication 8 2
|
| コンテスト | |
| ユーザー |
ゼット
|
| 提出日時 | 2025-08-24 01:41:53 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 20,857 bytes |
| コンパイル時間 | 3,653 ms |
| コンパイル使用メモリ | 296,784 KB |
| 実行使用メモリ | 104,188 KB |
| 最終ジャッジ日時 | 2025-08-24 01:43:16 |
| 合計ジャッジ時間 | 82,409 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 22 TLE * 6 |
ソースコード
/*
class FFT:
def primitive_root_constexpr(self, m):
if m == 2:
return 1
if m == 167772161:
return 3
if m == 469762049:
return 3
if m == 754974721:
return 11
if m == 998244353:
return 3
divs = [0] * 20
divs[0] = 2
cnt = 1
x = (m - 1) // 2
while x % 2 == 0:
x //= 2
i = 3
while i * i <= x:
if x % i == 0:
divs[cnt] = i
cnt += 1
while x % i == 0:
x //= i
i += 2
if x > 1:
divs[cnt] = x
cnt += 1
g = 2
while 1:
ok = True
for i in range(cnt):
if pow(g, (m - 1) // divs[i], m) == 1:
ok = False
break
if ok:
return g
g += 1
def bsf(self, x):
res = 0
while x % 2 == 0:
res += 1
x //= 2
return res
rank2 = 0
root = []
iroot = []
rate2 = []
irate2 = []
rate3 = []
irate3 = []
def __init__(self, MOD):
self.mod = MOD
self.g = self.primitive_root_constexpr(self.mod)
self.rank2 = self.bsf(self.mod - 1)
self.root = [0 for i in range(self.rank2 + 1)]
self.iroot = [0 for i in range(self.rank2 + 1)]
self.rate2 = [0 for i in range(self.rank2)]
self.irate2 = [0 for i in range(self.rank2)]
self.rate3 = [0 for i in range(self.rank2 - 1)]
self.irate3 = [0 for i in range(self.rank2 - 1)]
self.root[self.rank2] = pow(self.g, (self.mod - 1) >> self.rank2, self.mod)
self.iroot[self.rank2] = pow(self.root[self.rank2], self.mod - 2, self.mod)
for i in range(self.rank2 - 1, -1, -1):
self.root[i] = (self.root[i + 1] ** 2) % self.mod
self.iroot[i] = (self.iroot[i + 1] ** 2) % self.mod
prod = 1
iprod = 1
for i in range(self.rank2 - 1):
self.rate2[i] = (self.root[i + 2] * prod) % self.mod
self.irate2[i] = (self.iroot[i + 2] * iprod) % self.mod
prod = (prod * self.iroot[i + 2]) % self.mod
iprod = (iprod * self.root[i + 2]) % self.mod
prod = 1
iprod = 1
for i in range(self.rank2 - 2):
self.rate3[i] = (self.root[i + 3] * prod) % self.mod
self.irate3[i] = (self.iroot[i + 3] * iprod) % self.mod
prod = (prod * self.iroot[i + 3]) % self.mod
iprod = (iprod * self.root[i + 3]) % self.mod
def butterfly(self, a):
n = len(a)
h = (n - 1).bit_length()
LEN = 0
while LEN < h:
if h - LEN == 1:
p = 1 << (h - LEN - 1)
rot = 1
for s in range(1 << LEN):
offset = s << (h - LEN)
for i in range(p):
l = a[i + offset]
r = a[i + offset + p] * rot
a[i + offset] = (l + r) % self.mod
a[i + offset + p] = (l - r) % self.mod
rot *= self.rate2[(~s & -~s).bit_length() - 1]
rot %= self.mod
LEN += 1
else:
p = 1 << (h - LEN - 2)
rot = 1
imag = self.root[2]
for s in range(1 << LEN):
rot2 = (rot * rot) % self.mod
rot3 = (rot2 * rot) % self.mod
offset = s << (h - LEN)
for i in range(p):
a0 = a[i + offset]
a1 = a[i + offset + p] * rot
a2 = a[i + offset + 2 * p] * rot2
a3 = a[i + offset + 3 * p] * rot3
a1na3imag = (a1 - a3) % self.mod * imag
a[i + offset] = (a0 + a2 + a1 + a3) % self.mod
a[i + offset + p] = (a0 + a2 - a1 - a3) % self.mod
a[i + offset + 2 * p] = (a0 - a2 + a1na3imag) % self.mod
a[i + offset + 3 * p] = (a0 - a2 - a1na3imag) % self.mod
rot *= self.rate3[(~s & -~s).bit_length() - 1]
rot %= self.mod
LEN += 2
def butterfly_inv(self, a):
n = len(a)
h = (n - 1).bit_length()
LEN = h
while LEN:
if LEN == 1:
p = 1 << (h - LEN)
irot = 1
for s in range(1 << (LEN - 1)):
offset = s << (h - LEN + 1)
for i in range(p):
l = a[i + offset]
r = a[i + offset + p]
a[i + offset] = (l + r) % self.mod
a[i + offset + p] = (l - r) * irot % self.mod
irot *= self.irate2[(~s & -~s).bit_length() - 1]
irot %= self.mod
LEN -= 1
else:
p = 1 << (h - LEN)
irot = 1
iimag = self.iroot[2]
for s in range(1 << (LEN - 2)):
irot2 = (irot * irot) % self.mod
irot3 = (irot * irot2) % self.mod
offset = s << (h - LEN + 2)
for i in range(p):
a0 = a[i + offset]
a1 = a[i + offset + p]
a2 = a[i + offset + 2 * p]
a3 = a[i + offset + 3 * p]
a2na3iimag = (a2 - a3) * iimag % self.mod
a[i + offset] = (a0 + a1 + a2 + a3) % self.mod
a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % self.mod
a[i + offset + 2 * p] = (a0 + a1 - a2 - a3) * irot2 % self.mod
a[i + offset + 3 * p] = (
(a0 - a1 - a2na3iimag) * irot3 % self.mod
)
irot *= self.irate3[(~s & -~s).bit_length() - 1]
irot %= self.mod
LEN -= 2
def convolution(self, a, b):
n = len(a)
m = len(b)
if not (a) or not (b):
return []
if min(n, m) <= 40:
res = [0] * (n + m - 1)
for i in range(n):
for j in range(m):
res[i + j] += a[i] * b[j]
res[i + j] %= self.mod
return res
z = 1 << ((n + m - 2).bit_length())
a = a + [0] * (z - n)
b = b + [0] * (z - m)
self.butterfly(a)
self.butterfly(b)
c = [(a[i] * b[i]) % self.mod for i in range(z)]
self.butterfly_inv(c)
iz = pow(z, self.mod - 2, self.mod)
for i in range(n + m - 1):
c[i] = (c[i] * iz) % self.mod
return c[: n + m - 1]
N,K=map(int,input().split())
A=list(map(int,input().split()))
v=[0]*(N+1)
t=0
z=1
v[0]=1
dp=[0]*(N+1)
dp[0]=1
mod=998244353
p=[0]*(N+1)
G=[[] for i in range(N+1)]
G[0].append(0)
for i in range(N):
x=A[i]
z*=x
if z==8:
t+=1
z=1
p[i+1]=t
if z==1:
dp[i+1]+=v[t-1]
v[t]+=dp[i+1]
v[t]%=mod
G[t].append(i+1)
if z!=1:
print(0)
exit()
G[t]=[N]
G[0]=[0]
F=t
cp=[0]*(N+1)
cp[N]=1
v2=[0]*(N+1)
t=0
z=1
v2[0]=1
for i in range(N-1,-1,-1):
x=A[i]
z*=x
if z==8:
t+=1
z=1
if z==1:
cp[i]+=v2[t-1]
v2[t]+=cp[i]
v2[t]%=mod
result=0
Z=FFT(mod)
for e in range(1,F+1):
d=G[e][-1]-G[e-1][0]+1
u=[0]*d
v=[0]*d
for pos in G[e-1]:
a=pos-G[e-1][0]
u[a]=cp[pos]
v[G[e][-1]-pos]=dp[pos]
for pos in G[e]:
a=pos-G[e-1][0]
u[a]=cp[pos]
v[G[e][-1]-pos]=dp[pos]
h=Z.convolution(u,v)
for y in range(d,len(h)):
result+=pow((y-(d-1)),K,mod)*h[y]
result%=mod
for e in range(F+1):
d=G[e][-1]-G[e][0]+1
u=[0]*d
v=[0]*d
for pos in G[e]:
a=pos-G[e][0]
u[a]=cp[pos]
v[G[e][-1]-pos]=dp[pos]
h=Z.convolution(u,v)
for y in range(d,len(h)):
if e==0 or e==F:
result-=pow((y-(d-1)),K,mod)*h[y]
result%=mod
else:
result-=2*pow((y-(d-1)),K,mod)*h[y]
result%=mod
print(result)
*/
#include <bits/stdc++.h>
using namespace std;
/// Number-theoretic transform (NTT) convolution modulo a prime `mod`.
///
/// Structured like the AtCoder Library (ACL) convolution:
///  - root[i] / iroot[i] are roots of unity of order 2^i and their inverses
///    (root[rank2] = g^((mod-1) >> rank2), each lower entry is the square of
///    the next);
///  - rate2/rate3 (and irate2/irate3) are the twiddle step factors consumed
///    by the radix-2 / radix-4 butterfly passes, indexed by ctz(~s).
///
/// Requirements:
///  - `mod` is a prime below 2^31 (all modular sums below are done in a wider
///    type, so primes in [2^30, 2^31) such as 2013265921 are fine);
///  - convolution() inputs must already be reduced into [0, mod);
///  - when the transform path is used, n + m - 1 must not exceed 2^rank2
///    (checked with assert).
struct FFT {
    int mod, g, rank2;
    vector<int> root, iroot, rate2, irate2, rate3, irate3;

    /// a^e mod m by binary exponentiation; expects 0 <= a < m and e >= 0.
    static long long modpow_ll(long long a, long long e, long long m){
        long long r = 1 % m;
        while(e){
            if(e & 1) r = (r * a) % m;
            a = (a * a) % m;
            e >>= 1;
        }
        return r;
    }

    /// Smallest primitive root of the prime m (hard-coded for common NTT primes,
    /// otherwise found by testing g^((m-1)/q) != 1 for every prime factor q of m-1).
    int primitive_root_constexpr(int m){
        if (m == 2) return 1;
        if (m == 167772161) return 3;
        if (m == 469762049) return 3;
        if (m == 754974721) return 11;
        if (m == 998244353) return 3;
        vector<int> divs(20);
        divs[0] = 2;
        int cnt = 1;
        int x = (m - 1) / 2;
        while(x % 2 == 0) x /= 2;
        for(int i = 3; 1LL * i * i <= x; i += 2){
            if(x % i == 0){
                divs[cnt++] = i;
                while(x % i == 0) x /= i;
            }
        }
        if(x > 1) divs[cnt++] = x;
        for(int cand = 2;; cand++){
            bool ok = true;
            for(int i = 0; i < cnt; i++){
                if(modpow_ll(cand, (m - 1) / divs[i], m) == 1){
                    ok = false;
                    break;
                }
            }
            if(ok) return cand;
        }
    }

    /// Number of trailing zero bits of x (the 2-adic valuation); 0 for x == 0
    /// instead of looping forever.
    static int bsf(int x){
        int res = 0;
        while(x != 0 && x % 2 == 0){ res++; x /= 2; }
        return res;
    }

    FFT(int MOD){
        mod = MOD;
        g = primitive_root_constexpr(mod);
        rank2 = bsf(mod - 1);
        root.assign(rank2 + 1, 0);
        iroot.assign(rank2 + 1, 0);
        rate2.assign(rank2, 0);
        irate2.assign(rank2, 0);
        // Clamp at 0: for rank2 == 0 (mod == 2) a size of rank2 - 1 == -1 would
        // be converted to a huge size_t and throw.
        rate3.assign(max(rank2 - 1, 0), 0);
        irate3.assign(max(rank2 - 1, 0), 0);
        root[rank2] = (int)modpow_ll(g, (mod - 1) >> rank2, mod);
        iroot[rank2] = (int)modpow_ll(root[rank2], mod - 2, mod);
        for(int i = rank2 - 1; i >= 0; i--){
            root[i] = mulmod(root[i + 1], root[i + 1]);
            iroot[i] = mulmod(iroot[i + 1], iroot[i + 1]);
        }
        long long prod = 1, iprod = 1;
        for(int i = 0; i < rank2 - 1; i++){
            rate2[i] = int(1LL * root[i + 2] * prod % mod);
            irate2[i] = int(1LL * iroot[i + 2] * iprod % mod);
            prod = prod * iroot[i + 2] % mod;
            iprod = iprod * root[i + 2] % mod;
        }
        prod = 1; iprod = 1;
        for(int i = 0; i < rank2 - 2; i++){
            rate3[i] = int(1LL * root[i + 3] * prod % mod);
            irate3[i] = int(1LL * iroot[i + 3] * iprod % mod);
            prod = prod * iroot[i + 3] % mod;
            iprod = iprod * root[i + 3] % mod;
        }
    }

    /// (x + y) mod `mod` for x, y in [0, mod). The sum is widened first because
    /// x + y can exceed INT_MAX once mod > 2^30.
    static inline int addmod(int x, int y, int mod){
        long long s = (long long)x + y;
        return int(s >= mod ? s - mod : s);
    }
    /// (x - y) mod `mod` for x, y in [0, mod); x - y always fits in int here.
    static inline int submod(int x, int y, int mod){
        int d = x - y;
        return d < 0 ? d + mod : d;
    }
    /// (x * y) mod `mod` for x, y in [0, mod); the product is below 2^62.
    int mulmod(long long x, long long y) const { return int(x * y % mod); }

    /// In-place forward transform; a.size() must be a power of two.
    void butterfly(vector<int>& a){
        const int n = (int)a.size();
        int h = 0;  // (n-1).bit_length() in the Python reference
        while((1 << h) < n) ++h;
        int LEN = 0;
        while(LEN < h){
            if(h - LEN == 1){
                // Single remaining level: radix-2 pass.
                const int p = 1 << (h - LEN - 1);
                int rot = 1;
                for(unsigned s = 0; s < (1u << LEN); s++){
                    const int offset = int(s) << (h - LEN);
                    for(int i = 0; i < p; i++){
                        const int l = a[i + offset];
                        const int r = mulmod(a[i + offset + p], rot);
                        a[i + offset] = addmod(l, r, mod);
                        a[i + offset + p] = submod(l, r, mod);
                    }
                    // (~s & -~s).bit_length() - 1 in Python == ctz(~s)
                    rot = mulmod(rot, rate2[__builtin_ctz(~s)]);
                }
                LEN += 1;
            } else {
                // Two levels at once: radix-4 pass.
                const int p = 1 << (h - LEN - 2);
                int rot = 1;
                const int imag = root[2];
                for(unsigned s = 0; s < (1u << LEN); s++){
                    const int rot2 = mulmod(rot, rot);
                    const int rot3 = mulmod(rot2, rot);
                    const int offset = int(s) << (h - LEN);
                    for(int i = 0; i < p; i++){
                        const int a0 = a[i + offset];
                        const int a1 = mulmod(a[i + offset + p], rot);
                        const int a2 = mulmod(a[i + offset + 2 * p], rot2);
                        const int a3 = mulmod(a[i + offset + 3 * p], rot3);
                        const int a1na3imag = mulmod(submod(a1, a3, mod), imag);
                        const int s02 = addmod(a0, a2, mod);
                        const int s13 = addmod(a1, a3, mod);
                        const int d02 = submod(a0, a2, mod);
                        a[i + offset]         = addmod(s02, s13, mod);
                        a[i + offset + p]     = submod(s02, s13, mod);
                        a[i + offset + 2 * p] = addmod(d02, a1na3imag, mod);
                        a[i + offset + 3 * p] = submod(d02, a1na3imag, mod);
                    }
                    rot = mulmod(rot, rate3[__builtin_ctz(~s)]);
                }
                LEN += 2;
            }
        }
    }

    /// In-place inverse transform (without the final 1/n scaling); a.size()
    /// must be a power of two.
    void butterfly_inv(vector<int>& a){
        const int n = (int)a.size();
        int h = 0;
        while((1 << h) < n) ++h;
        int LEN = h;
        while(LEN){
            if(LEN == 1){
                const int p = 1 << (h - LEN);
                int irot = 1;
                for(unsigned s = 0; s < (1u << (LEN - 1)); s++){
                    const int offset = int(s) << (h - LEN + 1);
                    for(int i = 0; i < p; i++){
                        const int l = a[i + offset];
                        const int r = a[i + offset + p];
                        a[i + offset] = addmod(l, r, mod);
                        a[i + offset + p] = mulmod(submod(l, r, mod), irot);
                    }
                    irot = mulmod(irot, irate2[__builtin_ctz(~s)]);
                }
                LEN -= 1;
            } else {
                const int p = 1 << (h - LEN);
                int irot = 1;
                const int iimag = iroot[2];
                for(unsigned s = 0; s < (1u << (LEN - 2)); s++){
                    const int irot2 = mulmod(irot, irot);
                    const int irot3 = mulmod(irot, irot2);
                    const int offset = int(s) << (h - LEN + 2);
                    for(int i = 0; i < p; i++){
                        const int a0 = a[i + offset];
                        const int a1 = a[i + offset + p];
                        const int a2 = a[i + offset + 2 * p];
                        const int a3 = a[i + offset + 3 * p];
                        const int a2na3iimag = mulmod(submod(a2, a3, mod), iimag);
                        const int s01 = addmod(a0, a1, mod);
                        const int s23 = addmod(a2, a3, mod);
                        const int d01 = submod(a0, a1, mod);
                        a[i + offset]         = addmod(s01, s23, mod);
                        a[i + offset + p]     = mulmod(addmod(d01, a2na3iimag, mod), irot);
                        a[i + offset + 2 * p] = mulmod(submod(s01, s23, mod), irot2);
                        a[i + offset + 3 * p] = mulmod(submod(d01, a2na3iimag, mod), irot3);
                    }
                    irot = mulmod(irot, irate3[__builtin_ctz(~s)]);
                }
                LEN -= 2;
            }
        }
    }

    /// Smallest x with 2^x >= n.
    static int ceil_pow2(int n){
        int x = 0;
        while((1 << x) < n) ++x;
        return x;
    }

    /// Linear convolution of a and b modulo `mod` (result length n + m - 1,
    /// empty if either input is empty). Entries must be in [0, mod).
    /// Short inputs (min length <= 40) use the schoolbook product.
    vector<int> convolution(vector<int> a, vector<int> b){
        const int n = (int)a.size(), m = (int)b.size();
        if(n == 0 || m == 0) return {};
        if(min(n, m) <= 40){
            vector<int> res(n + m - 1, 0);
            for(int i = 0; i < n; i++){
                if(a[i] == 0) continue;
                for(int j = 0; j < m; j++){
                    res[i + j] = addmod(res[i + j], mulmod(a[i], b[j]), mod);
                }
            }
            return res;
        }
        const int lg = ceil_pow2(n + m - 1);
        // The root tables only support transform lengths up to 2^rank2.
        assert(lg <= rank2);
        const int z = 1 << lg;
        a.resize(z, 0);
        b.resize(z, 0);
        butterfly(a);
        butterfly(b);
        // Pointwise product written back into a to avoid a third buffer.
        for(int i = 0; i < z; i++) a[i] = mulmod(a[i], b[i]);
        butterfly_inv(a);
        const int iz = (int)modpow_ll(z, mod - 2, mod);
        a.resize(n + m - 1);
        for(int& x : a) x = mulmod(x, iz);
        return a;
    }
};
// Computes a^e mod `mod` by binary exponentiation.
// The base is reduced into [0, mod) first, so negative or oversized bases are
// handled and every intermediate product stays below mod^2 (fits in long long
// for mod below ~3e9). Requires e >= 0 and mod >= 1; mod == 1 yields 0.
static inline long long modpow(long long a, long long e, long long mod){
    // Negative exponents are unsupported (the original loop never terminated).
    assert(e >= 0);
    a %= mod;
    if(a < 0) a += mod;
    long long r = 1 % mod;
    for(; e > 0; e >>= 1){
        if(e & 1) r = r * a % mod;
        a = a * a % mod;
    }
    return r;
}
// Read-only, Python-style element access: a negative idx counts back from the
// end of v (idx == -1 is the last element), which reproduces the Python
// reference's v[t-1] reads when t == 0. A negative idx reaching past the front
// is clamped to element 0 rather than raising; non-negative idx is unchecked.
template<class T>
static inline T vec_at_with_neg(const vector<T>& v, long long idx){
    if(idx < 0){
        const long long from_end = (long long)v.size() + idx;
        idx = from_end < 0 ? 0 : from_end;
    }
    return v[(size_t)idx];
}
// Entry point: reads N, K and the sequence A, then prints one value modulo
// 998244353. This is a C++ port of the Python reference in the block comment
// at the top of the file, with two changes that do not alter the output:
//  * A running product that can no longer return to 1 or 8 ends the scan
//    early. This also prevents signed overflow of `z`, which is undefined
//    behaviour in C++ (Python integers are unbounded).
//  * The Python version convolves whole windows G[e-1] U G[e] and then
//    subtracts the within-group pairs again (interior groups twice, the two
//    end groups once). Only the cross pairs survive, so they are convolved
//    directly. That drops the second loop and never transforms the gap
//    between two groups.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int N;
    long long K;
    if(!(cin >> N >> K)) return 0;
    vector<int> A(N);
    for(int i = 0; i < N; i++) cin >> A[i];
    const int mod = 998244353;

    // With integer factors, a product that is 0 or has |z| > 8 can never come
    // back to exactly 1 or 8. Before this state is reached |z| <= 8, so
    // z * A[i] always fits in long long.
    auto is_dead = [](long long z){ return z == 0 || z > 8 || z < -8; };

    // Forward scan. t counts how many times the running product z reached
    // exactly 8 (z is then reset to 1). Whenever z == 1, position i+1 is
    // recorded in G[t] and dp[i+1] receives v[t-1], the dp total accumulated
    // in group t-1 (Python's v[-1] when t == 0).
    vector<long long> v(N + 1, 0), dp(N + 1, 0);
    vector<vector<int>> G(N + 1);
    v[0] = 1;
    dp[0] = 1;
    G[0].push_back(0);
    int t = 0;
    long long z = 1;
    for(int i = 0; i < N; i++){
        z *= A[i];
        if(z == 8){
            t += 1;
            z = 1;
        }
        if(is_dead(z)){
            // The scan could only end with z != 1, which prints 0 below anyway.
            cout << 0 << '\n';
            return 0;
        }
        if(z == 1){
            dp[i + 1] += vec_at_with_neg(v, t - 1);
            v[t] = (v[t] + dp[i + 1]) % mod;
            G[t].push_back(i + 1);
        }
    }
    if(z != 1){
        cout << 0 << '\n';
        return 0;
    }
    G[t] = { N };
    G[0] = { 0 };
    const int F = t;

    // Backward scan, the mirror image of the forward one, filling cp.
    vector<long long> cp(N + 1, 0), v2(N + 1, 0);
    cp[N] = 1;
    v2[0] = 1;
    t = 0;
    z = 1;
    for(int i = N - 1; i >= 0; i--){
        z *= A[i];
        if(z == 8){
            t += 1;
            z = 1;
        }
        // Once dead, no later step can hit z == 8 or z == 1, so cp and v2
        // would not change any more.
        if(is_dead(z)) break;
        if(z == 1){
            cp[i] += vec_at_with_neg(v2, t - 1);
            v2[t] = (v2[t] + cp[i]) % mod;
        }
    }

    // result = sum over e in [1, F] of
    //          sum_{j in G[e-1], i in G[e]} dp[j] * cp[i] * (i - j)^K  (mod p).
    // Every position in G[e-1] precedes every position in G[e], so i - j >= 1.
    FFT Z(mod);
    long long result = 0;
    for(int e = 1; e <= F; e++){
        const vector<int>& earlier = G[e - 1];  // dp side
        const vector<int>& later = G[e];        // cp side
        const int fi = later.front();
        const int lj = earlier.back();
        vector<int> u(later.back() - fi + 1, 0);
        vector<int> w(lj - earlier.front() + 1, 0);
        for(int i : later) u[i - fi] = (int)(cp[i] % mod);
        for(int j : earlier) w[lj - j] = (int)(dp[j] % mod);
        const vector<int> h = Z.convolution(std::move(u), std::move(w));
        // h[k] sums dp[j] * cp[i] over pairs with (i - fi) + (lj - j) == k,
        // i.e. with distance i - j == k + (fi - lj).
        const long long gap = fi - lj;
        for(int k = 0; k < (int)h.size(); k++){
            if(h[k] == 0) continue;
            result = (result + modpow(k + gap, K, mod) * h[k]) % mod;
        }
    }
    cout << result << '\n';
    return 0;
}
ゼット