結果

問題 No.3247 Multiplication 8 2
ユーザー ゼット
提出日時 2025-08-24 01:41:53
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
TLE  
実行時間 -
コード長 20,857 bytes
コンパイル時間 3,653 ms
コンパイル使用メモリ 296,784 KB
実行使用メモリ 104,188 KB
最終ジャッジ日時 2025-08-24 01:43:16
合計ジャッジ時間 82,409 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 22 TLE * 6
権限があれば一括ダウンロードができます

ソースコード

diff #

/*
class FFT:
    def primitive_root_constexpr(self, m):
        if m == 2:
            return 1
        if m == 167772161:
            return 3
        if m == 469762049:
            return 3
        if m == 754974721:
            return 11
        if m == 998244353:
            return 3
        divs = [0] * 20
        divs[0] = 2
        cnt = 1
        x = (m - 1) // 2
        while x % 2 == 0:
            x //= 2
        i = 3
        while i * i <= x:
            if x % i == 0:
                divs[cnt] = i
                cnt += 1
                while x % i == 0:
                    x //= i
            i += 2
        if x > 1:
            divs[cnt] = x
            cnt += 1
        g = 2
        while 1:
            ok = True
            for i in range(cnt):
                if pow(g, (m - 1) // divs[i], m) == 1:
                    ok = False
                    break
            if ok:
                return g
            g += 1

    def bsf(self, x):
        res = 0
        while x % 2 == 0:
            res += 1
            x //= 2
        return res

    rank2 = 0
    root = []
    iroot = []
    rate2 = []
    irate2 = []
    rate3 = []
    irate3 = []

    def __init__(self, MOD):
        self.mod = MOD
        self.g = self.primitive_root_constexpr(self.mod)
        self.rank2 = self.bsf(self.mod - 1)
        self.root = [0 for i in range(self.rank2 + 1)]
        self.iroot = [0 for i in range(self.rank2 + 1)]
        self.rate2 = [0 for i in range(self.rank2)]
        self.irate2 = [0 for i in range(self.rank2)]
        self.rate3 = [0 for i in range(self.rank2 - 1)]
        self.irate3 = [0 for i in range(self.rank2 - 1)]
        self.root[self.rank2] = pow(self.g, (self.mod - 1) >> self.rank2, self.mod)
        self.iroot[self.rank2] = pow(self.root[self.rank2], self.mod - 2, self.mod)
        for i in range(self.rank2 - 1, -1, -1):
            self.root[i] = (self.root[i + 1] ** 2) % self.mod
            self.iroot[i] = (self.iroot[i + 1] ** 2) % self.mod
        prod = 1
        iprod = 1
        for i in range(self.rank2 - 1):
            self.rate2[i] = (self.root[i + 2] * prod) % self.mod
            self.irate2[i] = (self.iroot[i + 2] * iprod) % self.mod
            prod = (prod * self.iroot[i + 2]) % self.mod
            iprod = (iprod * self.root[i + 2]) % self.mod
        prod = 1
        iprod = 1
        for i in range(self.rank2 - 2):
            self.rate3[i] = (self.root[i + 3] * prod) % self.mod
            self.irate3[i] = (self.iroot[i + 3] * iprod) % self.mod
            prod = (prod * self.iroot[i + 3]) % self.mod
            iprod = (iprod * self.root[i + 3]) % self.mod

    def butterfly(self, a):
        n = len(a)
        h = (n - 1).bit_length()

        LEN = 0
        while LEN < h:
            if h - LEN == 1:
                p = 1 << (h - LEN - 1)
                rot = 1
                for s in range(1 << LEN):
                    offset = s << (h - LEN)
                    for i in range(p):
                        l = a[i + offset]
                        r = a[i + offset + p] * rot
                        a[i + offset] = (l + r) % self.mod
                        a[i + offset + p] = (l - r) % self.mod
                    rot *= self.rate2[(~s & -~s).bit_length() - 1]
                    rot %= self.mod
                LEN += 1
            else:
                p = 1 << (h - LEN - 2)
                rot = 1
                imag = self.root[2]
                for s in range(1 << LEN):
                    rot2 = (rot * rot) % self.mod
                    rot3 = (rot2 * rot) % self.mod
                    offset = s << (h - LEN)
                    for i in range(p):
                        a0 = a[i + offset]
                        a1 = a[i + offset + p] * rot
                        a2 = a[i + offset + 2 * p] * rot2
                        a3 = a[i + offset + 3 * p] * rot3
                        a1na3imag = (a1 - a3) % self.mod * imag
                        a[i + offset] = (a0 + a2 + a1 + a3) % self.mod
                        a[i + offset + p] = (a0 + a2 - a1 - a3) % self.mod
                        a[i + offset + 2 * p] = (a0 - a2 + a1na3imag) % self.mod
                        a[i + offset + 3 * p] = (a0 - a2 - a1na3imag) % self.mod
                    rot *= self.rate3[(~s & -~s).bit_length() - 1]
                    rot %= self.mod
                LEN += 2

    def butterfly_inv(self, a):
        n = len(a)
        h = (n - 1).bit_length()
        LEN = h
        while LEN:
            if LEN == 1:
                p = 1 << (h - LEN)
                irot = 1
                for s in range(1 << (LEN - 1)):
                    offset = s << (h - LEN + 1)
                    for i in range(p):
                        l = a[i + offset]
                        r = a[i + offset + p]
                        a[i + offset] = (l + r) % self.mod
                        a[i + offset + p] = (l - r) * irot % self.mod
                    irot *= self.irate2[(~s & -~s).bit_length() - 1]
                    irot %= self.mod
                LEN -= 1
            else:
                p = 1 << (h - LEN)
                irot = 1
                iimag = self.iroot[2]
                for s in range(1 << (LEN - 2)):
                    irot2 = (irot * irot) % self.mod
                    irot3 = (irot * irot2) % self.mod
                    offset = s << (h - LEN + 2)
                    for i in range(p):
                        a0 = a[i + offset]
                        a1 = a[i + offset + p]
                        a2 = a[i + offset + 2 * p]
                        a3 = a[i + offset + 3 * p]
                        a2na3iimag = (a2 - a3) * iimag % self.mod
                        a[i + offset] = (a0 + a1 + a2 + a3) % self.mod
                        a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % self.mod
                        a[i + offset + 2 * p] = (a0 + a1 - a2 - a3) * irot2 % self.mod
                        a[i + offset + 3 * p] = (
                            (a0 - a1 - a2na3iimag) * irot3 % self.mod
                        )
                    irot *= self.irate3[(~s & -~s).bit_length() - 1]
                    irot %= self.mod
                LEN -= 2

    def convolution(self, a, b):
        n = len(a)
        m = len(b)
        if not (a) or not (b):
            return []
        if min(n, m) <= 40:
            res = [0] * (n + m - 1)
            for i in range(n):
                for j in range(m):
                    res[i + j] += a[i] * b[j]
                    res[i + j] %= self.mod
            return res
        z = 1 << ((n + m - 2).bit_length())
        a = a + [0] * (z - n)
        b = b + [0] * (z - m)
        self.butterfly(a)
        self.butterfly(b)
        c = [(a[i] * b[i]) % self.mod for i in range(z)]
        self.butterfly_inv(c)
        iz = pow(z, self.mod - 2, self.mod)
        for i in range(n + m - 1):
            c[i] = (c[i] * iz) % self.mod
        return c[: n + m - 1]
N,K=map(int,input().split())
A=list(map(int,input().split()))
v=[0]*(N+1)
t=0
z=1
v[0]=1
dp=[0]*(N+1)
dp[0]=1
mod=998244353
p=[0]*(N+1)
G=[[] for i in range(N+1)]
G[0].append(0)
for i in range(N):
  x=A[i]
  z*=x
  if z==8:
    t+=1
    z=1
  p[i+1]=t
  if z==1:
    dp[i+1]+=v[t-1]
    v[t]+=dp[i+1]
    v[t]%=mod
    G[t].append(i+1)
if z!=1:
  print(0)
  exit()
G[t]=[N]
G[0]=[0]
F=t
cp=[0]*(N+1)
cp[N]=1
v2=[0]*(N+1)
t=0
z=1
v2[0]=1
for i in range(N-1,-1,-1):
  x=A[i]
  z*=x
  if z==8:
    t+=1
    z=1
  if z==1:
    cp[i]+=v2[t-1]
    v2[t]+=cp[i]
    v2[t]%=mod
result=0
Z=FFT(mod)
for e in range(1,F+1):
  d=G[e][-1]-G[e-1][0]+1
  u=[0]*d
  v=[0]*d
  for pos in G[e-1]:
    a=pos-G[e-1][0]
    u[a]=cp[pos]
    v[G[e][-1]-pos]=dp[pos]
  for pos in G[e]:
    a=pos-G[e-1][0]
    u[a]=cp[pos]
    v[G[e][-1]-pos]=dp[pos]
  h=Z.convolution(u,v)
  for y in range(d,len(h)):
    result+=pow((y-(d-1)),K,mod)*h[y]
    result%=mod
for e in range(F+1):
  d=G[e][-1]-G[e][0]+1
  u=[0]*d
  v=[0]*d
  for pos in G[e]:
    a=pos-G[e][0]
    u[a]=cp[pos]
    v[G[e][-1]-pos]=dp[pos]
  h=Z.convolution(u,v)
  for y in range(d,len(h)):
    if e==0 or e==F:
      result-=pow((y-(d-1)),K,mod)*h[y]
      result%=mod
    else:
      result-=2*pow((y-(d-1)),K,mod)*h[y]
      result%=mod
print(result)
*/

#include <bits/stdc++.h>
using namespace std;

// Number-theoretic-transform (NTT) convolution modulo a prime `mod`.
// C++ port of the Python `FFT` class kept in the block comment at the top of
// this file. The table layout (root / rate2 / rate3 and their inverses) and the
// radix-4 butterflies look like AtCoder Library's convolution — NOTE(review):
// inferred from structure only; confirm if exact parity with ACL matters.
// Assumptions visible in the code:
//   * `mod` is prime and mod - 1 is divisible by a large enough power of two
//     (rank2 >= 2 so root[2] exists, and rank2 >= log2 of the transform size);
//   * 2 * mod fits in an int (sums of two residues are formed before reduction);
//   * every stored value is a residue in [0, mod).
struct FFT {
    int mod, g, rank2;
    vector<int> root, iroot, rate2, irate2, rate3, irate3;

    // Square-and-multiply: returns a^e mod m for e >= 0.
    // `a` is not reduced first, so a*a must fit in long long.
    static long long modpow_ll(long long a, long long e, long long m){
        long long r=1% m;
        while(e){
            if(e&1) r = (r*a) % m;
            a = (a*a) % m;
            e >>= 1;
        }
        return r;
    }

    // Returns a primitive root modulo m. Common NTT primes are answered from a
    // fixed list. Otherwise the distinct prime factors of m - 1 are collected by
    // trial division (2 is always divs[0]), and g = 2, 3, ... is tried until
    // g^((m-1)/q) != 1 for every prime factor q.
    // Assumes m is prime; 20 slots is ample for the distinct prime factors of a
    // 31-bit value.
    int primitive_root_constexpr(int m){
        if (m == 2) return 1;
        if (m == 167772161) return 3;
        if (m == 469762049) return 3;
        if (m == 754974721) return 11;
        if (m == 998244353) return 3;
        vector<int> divs(20);
        divs[0]=2; int cnt=1;
        int x=(m-1)/2;
        while(x%2==0) x/=2;
        for(int i=3; 1LL*i*i<=x; i+=2){
            if(x%i==0){
                divs[cnt++]=i;
                while(x%i==0) x/=i;
            }
        }
        if(x>1) divs[cnt++]=x;
        for(int g=2;;g++){
            bool ok=true;
            for(int i=0;i<cnt;i++){
                if(modpow_ll(g,(m-1)/divs[i],m)==1){
                    ok=false; break;
                }
            }
            if(ok) return g;
        }
    }

    // Number of trailing zero bits of x. x must be nonzero (x == 0 never exits).
    static int bsf(int x){
        int res=0;
        while(x%2==0){ res++; x/=2; }
        return res;
    }

    // Precomputes the root-of-unity tables for `MOD`:
    //   root[i]  = g^((mod-1) >> rank2) squared (rank2 - i) times, i.e. an element
    //              of order 2^i; iroot[i] is its inverse.
    //   rate2[i] / irate2[i] and rate3[i] / irate3[i] are the multipliers that move
    //   the running twiddle factor from one block to the next in the radix-2 and
    //   radix-4 passes of butterfly / butterfly_inv.
    // NOTE(review): with rank2 == 0 (MOD == 2) the assign(rank2 - 1, 0) calls would
    // request a huge size; only 998244353 is used in this file.
    FFT(int MOD){
        mod = MOD;
        g = primitive_root_constexpr(mod);
        rank2 = bsf(mod-1);
        root.assign(rank2+1,0);
        iroot.assign(rank2+1,0);
        rate2.assign(rank2,0);
        irate2.assign(rank2,0);
        rate3.assign(rank2-1,0);
        irate3.assign(rank2-1,0);
        root[rank2]  = (int)modpow_ll(g, (mod-1)>>rank2, mod);
        iroot[rank2] = (int)modpow_ll(root[rank2], mod-2, mod);  // Fermat inverse
        for(int i=rank2-1;i>=0;i--){
            root[i]  = int(1LL*root[i+1]*root[i+1] % mod);
            iroot[i] = int(1LL*iroot[i+1]*iroot[i+1] % mod);
        }
        long long prod=1, iprod=1;
        for(int i=0;i<rank2-1;i++){
            rate2[i]  = int(1LL*root[i+2]*prod % mod);
            irate2[i] = int(1LL*iroot[i+2]*iprod % mod);
            prod  = prod * iroot[i+2] % mod;
            iprod = iprod * root[i+2] % mod;
        }
        prod=1; iprod=1;
        for(int i=0;i<rank2-2;i++){
            rate3[i]  = int(1LL*root[i+3]*prod % mod);
            irate3[i] = int(1LL*iroot[i+3]*iprod % mod);
            prod  = prod * iroot[i+3] % mod;
            iprod = iprod * root[i+3] % mod;
        }
    }

    // Modular add / subtract for operands already in [0, mod).
    static inline int addmod(int x, int y, int mod){ int s=x+y; if(s>=mod) s-=mod; return s; }
    static inline int submod(int x, int y, int mod){ int d=x-y; if(d<0) d+=mod; return d; }

    // In-place forward NTT of `a`. a.size() is expected to be a power of two
    // (convolution pads to one). Two levels are processed per pass with radix-4
    // butterflies; a single radix-2 pass finishes when the remaining depth is odd.
    // The result is not permuted back to natural order. butterfly_inv undoes the
    // same ordering, so pointwise products taken in between are unaffected.
    void butterfly(vector<int>& a){
        int n = (int)a.size();
        int h = 0; // (n-1).bit_length()
        while((1<<h) < n) ++h;

        int LEN = 0;
        while(LEN < h){
            if(h - LEN == 1){
                // Radix-2 pass (only when one level remains).
                int p = 1 << (h - LEN - 1);
                int rot = 1;
                for(unsigned s=0; s<(1u<<LEN); s++){
                    int offset = int(s) << (h - LEN);
                    for(int i=0;i<p;i++){
                        int l = a[i + offset];
                        int r = int(1LL * a[i + offset + p] * rot % mod);
                        a[i + offset]       = addmod(l, r, mod);
                        a[i + offset + p]   = submod(l, r, mod);
                    }
                    // Python: (~s & -~s).bit_length() - 1 == ctz(~s), the index of the
                    // lowest zero bit of s; it picks the multiplier that advances the
                    // twiddle to block s+1 (the value after the last block is unused).
                    int idx = __builtin_ctz(~s);
                    rot = int(1LL * rot * rate2[idx] % mod);
                }
                LEN += 1;
            } else {
                // Radix-4 pass: combines four sub-blocks of length p at once.
                int p = 1 << (h - LEN - 2);
                int rot = 1;
                int imag = root[2];   // element of order 4
                for(unsigned s=0; s<(1u<<LEN); s++){
                    int rot2 = int(1LL*rot*rot % mod);
                    int rot3 = int(1LL*rot2*rot % mod);
                    int offset = int(s) << (h - LEN);
                    for(int i=0;i<p;i++){
                        int a0 = a[i + offset];
                        int a1 = int(1LL * a[i + offset + p    ] * rot  % mod);
                        int a2 = int(1LL * a[i + offset + 2 * p] * rot2 % mod);
                        int a3 = int(1LL * a[i + offset + 3 * p] * rot3 % mod);
                        int a1na3 = submod(a1, a3, mod);
                        long long a1na3imag = 1LL * a1na3 * imag % mod;

                        // a0 + a2 + a1 + a3
                        a[i + offset]         = (((a0 + a2) >= mod ? a0 + a2 - mod : a0 + a2) + ((a1 + a3) >= mod ? a1 + a3 - mod : a1 + a3));
                        if(a[i + offset] >= mod) a[i + offset] -= mod;

                        // a0 + a2 - a1 - a3
                        int tmp = ((a0 + a2) >= mod ? a0 + a2 - mod : a0 + a2);
                        int t2  = ((a1 + a3) >= mod ? a1 + a3 - mod : a1 + a3);
                        a[i + offset + p]     = submod(tmp, t2, mod);

                        // a0 - a2 +/- (a1 - a3) * imag
                        int t3 = submod(a0, a2, mod);
                        a[i + offset + 2 * p] = int((t3 + a1na3imag) % mod);
                        if(a[i + offset + 2 * p] < 0) a[i + offset + 2 * p] += mod;
                        a[i + offset + 3 * p] = int((t3 - a1na3imag) % mod);
                        if(a[i + offset + 3 * p] < 0) a[i + offset + 3 * p] += mod;
                    }
                    int idx = __builtin_ctz(~s);   // lowest zero bit of s, as above
                    rot = int(1LL * rot * rate3[idx] % mod);
                }
                LEN += 2;
            }
        }
    }

    // In-place inverse of butterfly(), WITHOUT the final 1/n scaling
    // (convolution multiplies by z^{-1} afterwards). Runs the passes in reverse
    // order using the inverse tables.
    void butterfly_inv(vector<int>& a){
        int n = (int)a.size();
        int h = 0; while((1<<h) < n) ++h;
        int LEN = h;
        while(LEN){
            if(LEN == 1){
                // Radix-2 pass.
                int p = 1 << (h - LEN);
                int irot = 1;
                for(unsigned s=0; s<(1u<<(LEN-1)); s++){
                    int offset = int(s) << (h - LEN + 1);
                    for(int i=0;i<p;i++){
                        int l = a[i + offset];
                        int r = a[i + offset + p];
                        a[i + offset]       = addmod(l, r, mod);
                        a[i + offset + p]   = int(1LL * submod(l, r, mod) * irot % mod);
                    }
                    int idx = __builtin_ctz(~s);   // lowest zero bit of s
                    irot = int(1LL * irot * irate2[idx] % mod);
                }
                LEN -= 1;
            } else {
                // Radix-4 pass.
                int p = 1 << (h - LEN);
                int irot = 1;
                int iimag = iroot[2];   // inverse of the order-4 element
                for(unsigned s=0; s<(1u<<(LEN-2)); s++){
                    int irot2 = int(1LL*irot*irot % mod);
                    int irot3 = int(1LL*irot*irot2 % mod);
                    int offset = int(s) << (h - LEN + 2);
                    for(int i=0;i<p;i++){
                        int a0 = a[i + offset];
                        int a1 = a[i + offset + p];
                        int a2 = a[i + offset + 2 * p];
                        int a3 = a[i + offset + 3 * p];
                        int a2na3 = submod(a2, a3, mod);
                        int a2na3iimag = int(1LL * a2na3 * iimag % mod);

                        // a0 + a1 + a2 + a3
                        a[i + offset]         = (((a0 + a1) >= mod ? a0 + a1 - mod : a0 + a1) + ((a2 + a3) >= mod ? a2 + a3 - mod : a2 + a3));
                        if(a[i + offset] >= mod) a[i + offset] -= mod;

                        // (a0 - a1 + (a2 - a3) * iimag) * irot
                        int t1 = submod(a0, a1, mod);
                        a[i + offset + p]     = int(1LL * (t1 + a2na3iimag) % mod);
                        if(a[i + offset + p] < 0) a[i + offset + p] += mod;
                        a[i + offset + p]     = int(1LL * a[i + offset + p] * irot % mod);

                        // (a0 + a1 - a2 - a3) * irot2
                        int t2 = submod((a0 + a1) >= mod ? a0 + a1 - mod : a0 + a1,
                                        (a2 + a3) >= mod ? a2 + a3 - mod : a2 + a3, mod);
                        a[i + offset + 2 * p] = int(1LL * t2 * irot2 % mod);

                        // (a0 - a1 - (a2 - a3) * iimag) * irot3
                        int t3 = submod(a0, a1, mod);
                        int t4 = submod(t3, a2na3iimag, mod);
                        a[i + offset + 3 * p] = int(1LL * t4 * irot3 % mod);
                    }
                    int idx = __builtin_ctz(~s);   // lowest zero bit of s
                    irot = int(1LL * irot * irate3[idx] % mod);
                }
                LEN -= 2;
            }
        }
    }

    // Smallest x with 2^x >= n.
    static int ceil_pow2(int n){
        int x=0; while((1<<x) < n) ++x; return x;
    }

    // Returns the linear convolution c[k] = sum_{i+j=k} a[i]*b[j] (mod `mod`),
    // of length n + m - 1, or an empty vector if either input is empty.
    // Entries are assumed to be in [0, mod). When the shorter input has at most
    // 40 entries a schoolbook O(n*m) loop is used. Otherwise both inputs are
    // zero-padded to a power of two >= n + m - 1 and multiplied via NTT.
    // Arguments are taken by value because they are resized and transformed in place.
    vector<int> convolution(vector<int> a, vector<int> b){
        int n = (int)a.size(), m = (int)b.size();
        if(n==0 || m==0) return {};
        if(min(n,m) <= 40){
            vector<int> res(n+m-1, 0);
            for(int i=0;i<n;i++) if(a[i]){
                for(int j=0;j<m;j++){
                    // Sum of two residues < 2*mod, which still fits in int.
                    res[i+j] = (res[i+j] + (int)(1LL*a[i]*b[j] % mod)) % mod;
                }
            }
            return res;
        }
        int z = 1 << ceil_pow2(n + m - 1);
        a.resize(z, 0);
        b.resize(z, 0);
        butterfly(a);
        butterfly(b);
        vector<int> c(z);
        for(int i=0;i<z;i++) c[i] = (int)(1LL*a[i]*b[i] % mod);
        butterfly_inv(c);
        int iz = (int)modpow_ll(z, mod-2, mod);   // 1/z, the scaling butterfly_inv omits
        c.resize(n+m-1);
        for(int i=0;i<n+m-1;i++) c[i] = (int)(1LL*c[i]*iz % mod);
        return c;
    }
};

// Binary exponentiation: returns a^e mod `mod` for e >= 0.
// The base is not reduced up front, so a*a must fit in long long.
// `1 % mod` makes mod == 1 yield 0.
static inline long long modpow(long long a, long long e, long long mod){
    long long acc = 1 % mod;
    for (long long base = a, rest = e; rest != 0; rest >>= 1) {
        if (rest & 1) {
            acc = acc * base % mod;
        }
        base = base * base % mod;
    }
    return acc;
}

// Read-only element access with Python-style negative indexing:
// idx == -1 is the last element, -2 the one before it, and so on.
// Indices below -size are clamped to element 0 rather than failing.
// Non-negative indices are used as-is, without bounds checking.
template<class T>
static inline T vec_at_with_neg(const vector<T>& v, long long idx){
    if (idx < 0) {
        const long long wrapped = static_cast<long long>(v.size()) + idx;
        idx = (wrapped < 0) ? 0 : wrapped;
    }
    return v[static_cast<size_t>(idx)];
}

// Returns pw with pw[j] == modpow(j, K, mod) for every j in [0, n] (n >= 0).
// j -> j^K is completely multiplicative, so a linear sieve needs a full
// O(log K) modpow only for primes and fills each composite with one
// multiplication. main() uses this instead of calling modpow once per
// convolution output term.
static vector<long long> pow_k_table(int n, long long K, long long mod){
    vector<long long> pw(n + 1, 0);
    pw[0] = modpow(0, K, mod);
    if(n >= 1) pw[1] = modpow(1, K, mod);
    vector<int> primes;
    vector<char> composite(n + 1, 0);
    for(int i = 2; i <= n; i++){
        if(!composite[i]){
            primes.push_back(i);
            pw[i] = modpow(i, K, mod);
        }
        for(int q : primes){
            const long long iq = 1LL * i * q;
            if(iq > n) break;
            composite[iq] = 1;
            pw[iq] = pw[i] * pw[q] % mod;
            // Stop at the smallest prime factor of i so each composite is set once.
            if(i % q == 0) break;
        }
    }
    return pw;
}

// Reads N, K and the sequence A, then prints the answer modulo 998244353.
// C++ port of the Python reference kept in the block comment at the top of
// this file. What the code below does:
//  * Forward scan: a running product z is reset to 1 whenever it reaches 8,
//    with t counting resets. Every prefix length at which z == 1 is recorded
//    in G[t], and dp[pos] = sum of dp over the positions already in G[t-1].
//  * Backward scan: the mirror image, building cp from the right end.
//  * For a window of consecutive groups, window_sum() returns
//      sum over p < q in the window of dp[p] * cp[q] * (q - p)^K   (mod p).
//    Two-group windows (e-1, e) are added. Single-group windows are
//    subtracted, once for the first/last group and twice for inner groups.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    long long K;
    if(!(cin >> N >> K)) return 0;
    vector<int> A(N);
    for(int i=0;i<N;i++) cin >> A[i];

    const int mod = 998244353;

    // ---- Forward scan --------------------------------------------------
    // v[t] = running sum of dp over the positions recorded in G[t].
    vector<long long> v(N+1, 0);
    vector<long long> dp(N+1, 0);
    vector<vector<int>> G(N+1);
    v[0] = 1;
    dp[0] = 1;
    G[0].push_back(0);
    int t = 0;
    long long z = 1;

    for(int i=0;i<N;i++){
        z *= A[i];
        if(z == 8){
            t += 1;
            z = 1;
        }
        // Between resets |z| never shrinks (or z sticks at 0), so once |z| > 8
        // or z == 0 it can never reach 8 or 1 again and the final `z != 1`
        // check would print 0. Exiting now gives the same output and keeps z
        // within [-8, 8] before each multiply, so it cannot overflow (signed
        // overflow is undefined behaviour).
        if(z == 0 || z > 8 || z < -8){
            cout << 0 << '\n';
            return 0;
        }
        if(z == 1){
            dp[i+1] += vec_at_with_neg(v, t-1);   // v[t-1]; Python-style v[-1] when t == 0
            v[t] += dp[i+1];
            v[t] %= mod;
            G[t].push_back(i+1);
        }
    }

    if(z != 1){
        cout << 0 << '\n';
        return 0;
    }
    // The outer groups keep only the endpoints: prefix length N for the last
    // group and 0 for the first (same order as the Python reference).
    G[t] = { N };
    G[0] = { 0 };
    const int F = t;

    // ---- Backward scan (mirror of the forward scan) ----------------------
    vector<long long> cp(N+1, 0);
    vector<long long> v2(N+1, 0);
    cp[N] = 1;
    v2[0] = 1;
    t = 0;
    z = 1;
    for(int i=N-1;i>=0;i--){
        z *= A[i];
        if(z == 8){
            t += 1;
            z = 1;
        }
        // Same argument as above: z can never return to 1, so no further cp
        // entry can change. Stop before z could overflow.
        if(z == 0 || z > 8 || z < -8) break;
        if(z == 1){
            cp[i] += vec_at_with_neg(v2, t-1);    // v2[t-1]; Python-style v2[-1] when t == 0
            v2[t] += cp[i];
            v2[t] %= mod;
        }
    }

    // powK[j] = j^K mod p. Bases below are q - p <= d - 1 <= N.
    const vector<long long> powK = pow_k_table(N, K, mod);
    FFT Z(mod);

    // Lays cp (at offset pos - lo) and dp (reversed, at offset hi - pos) over
    // the positions of groups g_first..g_last and convolves them. Output index y
    // pairs dp[p] with cp[q] where q - p = y - (d - 1), so y >= d selects p < q.
    auto window_sum = [&](int g_first, int g_last) -> long long {
        const int lo = G[g_first].front();
        const int hi = G[g_last].back();
        const int d = hi - lo + 1;
        vector<int> u(d, 0), w(d, 0);
        for(int g = g_first; g <= g_last; g++){
            for(int pos : G[g]){
                u[pos - lo] = (int)(cp[pos] % mod);
                w[hi - pos] = (int)(dp[pos] % mod);
            }
        }
        const vector<int> h = Z.convolution(std::move(u), std::move(w));
        long long acc = 0;
        for(int y = d; y < (int)h.size(); y++){
            acc = (acc + powK[y - (d - 1)] * h[y]) % mod;
        }
        return acc;
    };

    long long result = 0;
    for(int e = 1; e <= F; e++){
        result = (result + window_sum(e - 1, e)) % mod;
    }
    for(int e = 0; e <= F; e++){
        const long long coef = (e == 0 || e == F) ? 1 : 2;
        result = ((result - coef * window_sum(e, e)) % mod + mod) % mod;
    }

    cout << result << '\n';
    return 0;
}
0