結果

問題 No.2313 Product of Subsequence (hard)
ユーザー lam6er
提出日時 2025-03-31 17:50:53
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,367 bytes
コンパイル時間 296 ms
コンパイル使用メモリ 82,160 KB
実行使用メモリ 268,356 KB
最終ジャッジ日時 2025-03-31 17:52:07
合計ジャッジ時間 9,672 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 10 TLE * 1 -- * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def factorize(n):
    factors = {}
    i = 2
    while i * i <= n:
        if n % i == 0:
            cnt = 0
            while n % i == 0:
                cnt +=1
                n //=i
            factors[i] = cnt
        i +=1
    if n >1:
        factors[n] =1
    return factors

def main():
    import sys
    N, K = map(int, sys.stdin.readline().split())
    A = list(map(int, sys.stdin.readline().split()))
    
    if K ==1:
        print((pow(2, N, MOD) -1) % MOD)
        return
    
    k_factors = factorize(K)
    primes = list(k_factors.items())
    m = len(primes)
    
    core = []
    M =0
    
    for a in A:
        contributions = []
        valid = False
        for (p, e) in primes:
            cnt =0
            temp = a
            while temp % p ==0:
                cnt +=1
                temp //=p
            contributions.append(cnt)
            if cnt >0:
                valid = True
        if valid:
            core.append(contributions)
        else:
            M +=1
    
    sum_ans =0
    for mask in range(1 << m):
        bits = bin(mask).count('1')
        T = []
        Ts = []
        e_list = []
        for i in range(m):
            if (mask >> i) &1:
                T.append(i)
                Ts.append(i)
                e_list.append(primes[i][1])
        T_set = set(Ts)
        
        filtered = []
        zero_count =0
        for contri in core:
            include = True
            for idx in Ts:
                if contri[idx] >= primes[idx][1]:
                    include =False
                    break
            if not include:
                continue
            is_zero = True
            new_contri = []
            for idx in Ts:
                c = contri[idx]
                new_contri.append(c)
                if c !=0:
                    is_zero = False
            if is_zero:
                zero_count +=1
            else:
                filtered.append(new_contri)
        
        state = {}
        initial = tuple([0]*len(Ts))
        state[initial] =1
        
        for f in filtered:
            new_state = {}
            for s in state:
                cnt = state[s]
                new_s = list(s)
                valid = True
                for i in range(len(new_s)):
                    new_s[i] += f[i]
                    if new_s[i] >= e_list[i]:
                        valid = False
                        break
                if valid:
                    new_s_tuple = tuple(new_s)
                    if new_s_tuple in new_state:
                        new_state[new_s_tuple] = (new_state[new_s_tuple] + cnt) % MOD
                    else:
                        new_state[new_s_tuple] = cnt % MOD
                if s in new_state:
                    new_state[s] = (new_state[s] + cnt) % MOD
                else:
                    new_state[s] = cnt % MOD
            state = new_state
        
        total =0
        for v in state.values():
            total = (total + v) % MOD
        total = total * pow(2, zero_count, MOD) % MOD
        
        sign = (-1)**(bits)
        if sign ==1:
            sum_ans = (sum_ans + total) % MOD
        else:
            sum_ans = (sum_ans - total) % MOD
    
    ans = sum_ans % MOD
    ans = ans * pow(2, M, MOD) % MOD
    print(ans)

if __name__ == "__main__":
    main()
0