結果

問題 No.2313 Product of Subsequence (hard)
ユーザー lam6er
提出日時 2025-04-15 23:54:31
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,927 bytes
コンパイル時間 188 ms
コンパイル使用メモリ 82,112 KB
実行使用メモリ 286,744 KB
最終ジャッジ日時 2025-04-15 23:55:41
合計ジャッジ時間 7,351 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 10 TLE * 1 -- * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

MOD = 998244353

def factorize(k):
    factors = {}
    i = 2
    while i * i <= k:
        while k % i == 0:
            factors[i] = factors.get(i, 0) + 1
            k //= i
        i += 1
    if k > 1:
        factors[k] = 1
    return factors

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N, K = int(input[ptr]), int(input[ptr+1])
    ptr +=2
    A = list(map(int, input[ptr:ptr+N]))
    
    if K == 1:
        print((pow(2, N, MOD) - 1) % MOD)
        return
    
    factors = factorize(K)
    primes = list(factors.keys())
    m = len(primes)
    e_list = [factors[p] for p in primes]
    
    elements = []
    for a in A:
        exps = []
        valid_mask = 0
        for i in range(m):
            p = primes[i]
            e = 0
            while a % p == 0:
                e +=1
                a //=p
            exps.append(e)
            if e < e_list[i]:
                valid_mask |= (1 << i)
        elements.append( (valid_mask, exps) )
    
    pow2 = [1] * (N+1)
    for i in range(1, N+1):
        pow2[i] = (pow2[i-1] * 2) % MOD
    total = (pow2[N] - 1) % MOD
    
    from itertools import combinations
    
    inclusion_exclusion = 0
    for mask in range(1, 1 << m):
        S = []
        S_indices = []
        for i in range(m):
            if (mask >> i) & 1:
                S.append(primes[i])
                S_indices.append(i)
        k = len(S_indices)
        current_e_list = [e_list[i] for i in S_indices]
        allowed_exponents = []
        for valid_mask, exps in elements:
            if (valid_mask & mask) == mask:
                selected_exps = [exps[i] for i in S_indices]
                allowed_exponents.append(selected_exps)
        
        if not allowed_exponents:
            continue
        
        m_subs = len(S_indices)
        max_sums = [current_e_list[i] - 1 for i in range(m_subs)]
        dp = defaultdict(int)
        initial_state = tuple([0]*m_subs)
        dp[initial_state] = 1
        
        for exps in allowed_exponents:
            new_dp = defaultdict(int)
            for state, cnt in dp.items():
                new_dp[state] = (new_dp[state] + cnt) % MOD
                new_state = list(state)
                valid = True
                for i in range(m_subs):
                    new_state[i] += exps[i]
                    if new_state[i] > max_sums[i]:
                        valid = False
                        break
                if valid:
                    new_state_t = tuple(new_state)
                    new_dp[new_state_t] = (new_dp[new_state_t] + cnt) % MOD
            dp = new_dp
        
        count = (sum(dp.values()) - 1) % MOD
        sign = (-1) ** (k + 1)
        inclusion_exclusion = (inclusion_exclusion + sign * count) % MOD
    
    ans = (total - inclusion_exclusion) % MOD
    print(ans)

if __name__ == '__main__':
    main()
0