結果
問題 |
No.3247 Multiplication 8 2
|
ユーザー |
![]() |
提出日時 | 2025-08-24 01:41:53 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 |
TLE
|
実行時間 | - |
コード長 | 20,857 bytes |
コンパイル時間 | 3,653 ms |
コンパイル使用メモリ | 296,784 KB |
実行使用メモリ | 104,188 KB |
最終ジャッジ日時 | 2025-08-24 01:43:16 |
合計ジャッジ時間 | 82,409 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 22 TLE * 6 |
ソースコード
// C++ port of a Python reference solution to yukicoder No.3247
// ("Multiplication 8 2").  The Python original used to be kept here as a
// block comment; it was removed because line collapsing had destroyed its
// indentation and it was no longer runnable.
//
// Outline of what main() computes:
//   * Scan A left to right with a running product z.  Whenever z reaches 8 a
//     new block starts (t++, z = 1).  At every position where z == 1 a DP
//     value dp[pos] is recorded and the position is appended to G[t].
//   * If the scan does not end with z == 1 the program prints 0.
//   * A symmetric right-to-left scan fills cp[pos].
//   * For each pair of adjacent groups (G[e-1], G[e]) and for each single
//     group G[e], cp and dp are convolved with an NTT.  Each upper-half
//     coefficient h[y] is weighted by (y - (d - 1))^K and added (pairs) or
//     subtracted (single groups; twice for interior groups).
#include <bits/stdc++.h>
using namespace std;

// Number-theoretic transform modulo an NTT-friendly prime, using
// radix-2 / radix-4 butterflies in the style of the AtCoder Library.
struct FFT {
    int mod, g, rank2;
    vector<int> root, iroot, rate2, irate2, rate3, irate3;

    // a^e mod m by binary exponentiation (expects e >= 0).
    static long long modpow_ll(long long a, long long e, long long m) {
        long long r = 1 % m;
        while (e) {
            if (e & 1) r = r * a % m;
            a = a * a % m;
            e >>= 1;
        }
        return r;
    }

    // A primitive root of the prime m; common NTT primes are hard-coded,
    // otherwise search g = 2, 3, ... against the prime factors of m - 1.
    int primitive_root_constexpr(int m) {
        if (m == 2) return 1;
        if (m == 167772161) return 3;
        if (m == 469762049) return 3;
        if (m == 754974721) return 11;
        if (m == 998244353) return 3;
        vector<int> divs(20);
        divs[0] = 2;
        int cnt = 1;
        int x = (m - 1) / 2;
        while (x % 2 == 0) x /= 2;
        for (int i = 3; 1LL * i * i <= x; i += 2) {
            if (x % i == 0) {
                divs[cnt++] = i;
                while (x % i == 0) x /= i;
            }
        }
        if (x > 1) divs[cnt++] = x;
        for (int g = 2;; g++) {
            bool ok = true;
            for (int i = 0; i < cnt; i++) {
                if (modpow_ll(g, (m - 1) / divs[i], m) == 1) {
                    ok = false;
                    break;
                }
            }
            if (ok) return g;
        }
    }

    // Number of trailing zero bits of x (x > 0).
    static int bsf(int x) {
        int res = 0;
        while (x % 2 == 0) {
            res++;
            x /= 2;
        }
        return res;
    }

    // Precompute the root / rate tables used by the butterflies.
    FFT(int MOD) {
        mod = MOD;
        g = primitive_root_constexpr(mod);
        rank2 = bsf(mod - 1);
        root.assign(rank2 + 1, 0);
        iroot.assign(rank2 + 1, 0);
        rate2.assign(rank2, 0);
        irate2.assign(rank2, 0);
        rate3.assign(rank2 - 1, 0);
        irate3.assign(rank2 - 1, 0);
        root[rank2] = (int)modpow_ll(g, (mod - 1) >> rank2, mod);
        iroot[rank2] = (int)modpow_ll(root[rank2], mod - 2, mod);
        for (int i = rank2 - 1; i >= 0; i--) {
            root[i] = int(1LL * root[i + 1] * root[i + 1] % mod);
            iroot[i] = int(1LL * iroot[i + 1] * iroot[i + 1] % mod);
        }
        long long prod = 1, iprod = 1;
        for (int i = 0; i < rank2 - 1; i++) {
            rate2[i] = int(1LL * root[i + 2] * prod % mod);
            irate2[i] = int(1LL * iroot[i + 2] * iprod % mod);
            prod = prod * iroot[i + 2] % mod;
            iprod = iprod * root[i + 2] % mod;
        }
        prod = 1;
        iprod = 1;
        for (int i = 0; i < rank2 - 2; i++) {
            rate3[i] = int(1LL * root[i + 3] * prod % mod);
            irate3[i] = int(1LL * iroot[i + 3] * iprod % mod);
            prod = prod * iroot[i + 3] % mod;
            iprod = iprod * root[i + 3] % mod;
        }
    }

    // (x + y) mod m and (x - y) mod m for x, y in [0, m).
    static inline int addmod(int x, int y, int mod) {
        int s = x + y;
        if (s >= mod) s -= mod;
        return s;
    }
    static inline int submod(int x, int y, int mod) {
        int d = x - y;
        if (d < 0) d += mod;
        return d;
    }

    // Smallest x with (1 << x) >= n.
    static int ceil_pow2(int n) {
        int x = 0;
        while ((1 << x) < n) ++x;
        return x;
    }

    // In-place forward transform; a.size() must be a power of two.
    void butterfly(vector<int>& a) {
        const int h = ceil_pow2((int)a.size());
        int len = 0;
        while (len < h) {
            if (h - len == 1) {
                // Radix-2 layer.
                const int p = 1 << (h - len - 1);
                int rot = 1;
                for (unsigned s = 0; s < (1u << len); s++) {
                    const int offset = int(s) << (h - len);
                    for (int i = 0; i < p; i++) {
                        const int l = a[i + offset];
                        const int r = int(1LL * a[i + offset + p] * rot % mod);
                        a[i + offset] = addmod(l, r, mod);
                        a[i + offset + p] = submod(l, r, mod);
                    }
                    // ctz(~s) == (~s & -~s).bit_length() - 1 in the Python reference.
                    rot = int(1LL * rot * rate2[__builtin_ctz(~s)] % mod);
                }
                len += 1;
            } else {
                // Radix-4 layer.
                const int p = 1 << (h - len - 2);
                const int imag = root[2];
                int rot = 1;
                for (unsigned s = 0; s < (1u << len); s++) {
                    const int rot2 = int(1LL * rot * rot % mod);
                    const int rot3 = int(1LL * rot2 * rot % mod);
                    const int offset = int(s) << (h - len);
                    for (int i = 0; i < p; i++) {
                        const int a0 = a[i + offset];
                        const int a1 = int(1LL * a[i + offset + p] * rot % mod);
                        const int a2 = int(1LL * a[i + offset + 2 * p] * rot2 % mod);
                        const int a3 = int(1LL * a[i + offset + 3 * p] * rot3 % mod);
                        const int a1na3imag =
                            int(1LL * submod(a1, a3, mod) * imag % mod);
                        const int s02 = addmod(a0, a2, mod);
                        const int s13 = addmod(a1, a3, mod);
                        const int d02 = submod(a0, a2, mod);
                        a[i + offset] = addmod(s02, s13, mod);
                        a[i + offset + p] = submod(s02, s13, mod);
                        a[i + offset + 2 * p] = addmod(d02, a1na3imag, mod);
                        a[i + offset + 3 * p] = submod(d02, a1na3imag, mod);
                    }
                    rot = int(1LL * rot * rate3[__builtin_ctz(~s)] % mod);
                }
                len += 2;
            }
        }
    }

    // In-place inverse transform (without the 1/n scaling); a.size() must be
    // a power of two.
    void butterfly_inv(vector<int>& a) {
        const int h = ceil_pow2((int)a.size());
        int len = h;
        while (len) {
            if (len == 1) {
                // Radix-2 layer.
                const int p = 1 << (h - len);
                int irot = 1;
                for (unsigned s = 0; s < (1u << (len - 1)); s++) {
                    const int offset = int(s) << (h - len + 1);
                    for (int i = 0; i < p; i++) {
                        const int l = a[i + offset];
                        const int r = a[i + offset + p];
                        a[i + offset] = addmod(l, r, mod);
                        a[i + offset + p] = int(1LL * submod(l, r, mod) * irot % mod);
                    }
                    irot = int(1LL * irot * irate2[__builtin_ctz(~s)] % mod);
                }
                len -= 1;
            } else {
                // Radix-4 layer.
                const int p = 1 << (h - len);
                const int iimag = iroot[2];
                int irot = 1;
                for (unsigned s = 0; s < (1u << (len - 2)); s++) {
                    const int irot2 = int(1LL * irot * irot % mod);
                    const int irot3 = int(1LL * irot * irot2 % mod);
                    const int offset = int(s) << (h - len + 2);
                    for (int i = 0; i < p; i++) {
                        const int a0 = a[i + offset];
                        const int a1 = a[i + offset + p];
                        const int a2 = a[i + offset + 2 * p];
                        const int a3 = a[i + offset + 3 * p];
                        const int a2na3iimag =
                            int(1LL * submod(a2, a3, mod) * iimag % mod);
                        const int s01 = addmod(a0, a1, mod);
                        const int s23 = addmod(a2, a3, mod);
                        const int d01 = submod(a0, a1, mod);
                        a[i + offset] = addmod(s01, s23, mod);
                        a[i + offset + p] =
                            int(1LL * addmod(d01, a2na3iimag, mod) * irot % mod);
                        a[i + offset + 2 * p] =
                            int(1LL * submod(s01, s23, mod) * irot2 % mod);
                        a[i + offset + 3 * p] =
                            int(1LL * submod(d01, a2na3iimag, mod) * irot3 % mod);
                    }
                    irot = int(1LL * irot * irate3[__builtin_ctz(~s)] % mod);
                }
                len -= 2;
            }
        }
    }

    // Cyclic-free convolution of a and b modulo `mod` (entries in [0, mod)).
    // Returns a vector of length a.size() + b.size() - 1, or {} if either is
    // empty.  Small inputs use the schoolbook product.
    vector<int> convolution(vector<int> a, vector<int> b) {
        const int n = (int)a.size(), m = (int)b.size();
        if (n == 0 || m == 0) return {};
        if (min(n, m) <= 40) {
            vector<int> res(n + m - 1, 0);
            for (int i = 0; i < n; i++) {
                if (!a[i]) continue;
                for (int j = 0; j < m; j++)
                    res[i + j] = (res[i + j] + (int)(1LL * a[i] * b[j] % mod)) % mod;
            }
            return res;
        }
        const int z = 1 << ceil_pow2(n + m - 1);
        a.resize(z, 0);
        b.resize(z, 0);
        butterfly(a);
        butterfly(b);
        for (int i = 0; i < z; i++) a[i] = (int)(1LL * a[i] * b[i] % mod);
        butterfly_inv(a);
        const int iz = (int)modpow_ll(z, mod - 2, mod);
        a.resize(n + m - 1);
        for (int& c : a) c = (int)(1LL * c * iz % mod);
        return a;
    }
};

// a^e mod `mod` by binary exponentiation (expects e >= 0).
static inline long long modpow(long long a, long long e, long long mod) {
    long long r = 1 % mod;
    while (e) {
        if (e & 1) r = r * a % mod;
        a = a * a % mod;
        e >>= 1;
    }
    return r;
}

// Python-like read v[idx]: a negative idx counts from the end (the Python
// reference reads v[t-1] with t == 0, i.e. v[-1]).
template <class T>
static inline T vec_at_with_neg(const vector<T>& v, long long idx) {
    const long long n = (long long)v.size();
    if (idx >= 0) return v[(size_t)idx];
    long long j = n + idx;
    if (j < 0) j = 0;  // safety clamp; not expected on this code path
    return v[(size_t)j];
}

// Multiply the running block product z by x.  Requires z != 0 and
// |z| <= 8 on entry.  Returns false when the product leaves [-8, 8] or
// becomes 0.  From then on the product (an integer that never shrinks in
// magnitude unless it becomes 0) can never again equal 1 or 8.  Checking x
// before multiplying keeps z * x far from long long overflow.
static inline bool extend_product(long long& z, long long x) {
    if (x == 0 || x > 8 || x < -8) return false;
    z *= x;  // |z| <= 64 here
    return -8 <= z && z <= 8;
}

// pw[b] = b^K mod `mod` for 0 <= b <= n.  x -> x^K is completely
// multiplicative, so only primes need a full exponentiation; each composite
// b reuses pw[spf(b)] * pw[b / spf(b)].
static vector<long long> power_table(int n, long long K, long long mod) {
    vector<long long> pw(n + 1, 0);
    vector<int> spf(n + 1, 0);
    pw[0] = modpow(0, K, mod);
    if (n >= 1) pw[1] = modpow(1, K, mod);
    for (int b = 2; b <= n; b++) {
        if (spf[b] == 0) {  // b is prime
            for (long long j = b; j <= n; j += b)
                if (spf[j] == 0) spf[j] = b;
            pw[b] = modpow(b, K, mod);
        } else {
            pw[b] = pw[spf[b]] * pw[b / spf[b]] % mod;
        }
    }
    return pw;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    long long K;
    if (!(cin >> N >> K)) return 0;
    // long long so that any input the Python reference accepted up to 64 bits
    // is read correctly; extend_product never multiplies by |x| > 8.
    vector<long long> A(N);
    for (auto& x : A) cin >> x;
    const int mod = 998244353;

    // ---- Forward scan: blocks whose product is 8, dp at "product == 1" points.
    vector<long long> v(N + 1, 0), dp(N + 1, 0);
    v[0] = 1;
    dp[0] = 1;
    vector<vector<int>> G(N + 1);
    G[0].push_back(0);
    int t = 0;
    long long z = 1;
    for (int i = 0; i < N; i++) {
        if (!extend_product(z, A[i])) {
            // z can never return to 1, so the final check below would fail.
            cout << 0 << '\n';
            return 0;
        }
        if (z == 8) {
            ++t;
            z = 1;
        }
        if (z == 1) {
            dp[i + 1] += vec_at_with_neg(v, t - 1);  // v[t-1] (Python allows -1)
            v[t] = (v[t] + dp[i + 1]) % mod;
            G[t].push_back(i + 1);
        }
    }
    if (z != 1) {
        cout << 0 << '\n';
        return 0;
    }
    G[t] = {N};
    G[0] = {0};
    const int F = t;

    // ---- Backward scan: same construction from the right, filling cp.
    vector<long long> cp(N + 1, 0), v2(N + 1, 0);
    cp[N] = 1;
    v2[0] = 1;
    t = 0;
    z = 1;
    for (int i = N - 1; i >= 0; i--) {
        // Once the product is out of range no further updates can happen.
        if (!extend_product(z, A[i])) break;
        if (z == 8) {
            ++t;
            z = 1;
        }
        if (z == 1) {
            cp[i] += vec_at_with_neg(v2, t - 1);  // v2[t-1]
            v2[t] = (v2[t] + cp[i]) % mod;
        }
    }

    // Every weight base below is y - (d - 1) with d - 1 <= N, so one table of
    // b^K for b <= N replaces the per-term modpow calls.
    const vector<long long> pw = power_table(N, K, mod);

    FFT Z(mod);
    long long result = 0;

    // Pairs of adjacent groups: add sum_y h[y] * (y - (d-1))^K.
    for (int e = 1; e <= F; e++) {
        const int lo = G[e - 1].front();
        const int hi = G[e].back();
        const int d = hi - lo + 1;
        vector<int> u(d, 0), w(d, 0);
        for (const vector<int>* grp : {&G[e - 1], &G[e]}) {
            for (int pos : *grp) {
                u[pos - lo] = (int)(cp[pos] % mod);
                w[hi - pos] = (int)(dp[pos] % mod);
            }
        }
        const vector<int> h = Z.convolution(std::move(u), std::move(w));
        for (int y = d; y < (int)h.size(); y++)
            result = (result + pw[y - (d - 1)] * h[y]) % mod;
    }

    // Single groups: subtract once for the two end groups, twice otherwise.
    for (int e = 0; e <= F; e++) {
        const int lo = G[e].front();
        const int hi = G[e].back();
        const int d = hi - lo + 1;
        vector<int> u(d, 0), w(d, 0);
        for (int pos : G[e]) {
            u[pos - lo] = (int)(cp[pos] % mod);
            w[hi - pos] = (int)(dp[pos] % mod);
        }
        const vector<int> h = Z.convolution(std::move(u), std::move(w));
        const long long mult = (e == 0 || e == F) ? 1 : 2;
        for (int y = d; y < (int)h.size(); y++) {
            const long long term = pw[y - (d - 1)] * h[y] % mod;
            result = (result - mult * term % mod + mod) % mod;
        }
    }

    cout << (result % mod + mod) % mod << '\n';
    return 0;
}