結果

問題 No.2313 Product of Subsequence (hard)
ユーザー lam6er
提出日時 2025-04-15 23:42:30
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,285 bytes
コンパイル時間 132 ms
コンパイル使用メモリ 82,644 KB
実行使用メモリ 262,548 KB
最終ジャッジ日時 2025-04-15 23:44:50
合計ジャッジ時間 12,550 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 10 TLE * 1 -- * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def main():
    import sys
    from sys import stdin
    from collections import defaultdict

    N, K = map(int, stdin.readline().split())
    A = list(map(int, stdin.readline().split()))

    if K == 1:
        print((pow(2, N, MOD) - 1) % MOD)
        return

    def factorize(k):
        factors = {}
        i = 2
        while i * i <= k:
            while k % i == 0:
                factors[i] = factors.get(i, 0) + 1
                k = k // i
            i += 1
        if k > 1:
            factors[k] = 1
        return factors

    k_factors = factorize(K)
    if not k_factors:
        print((pow(2, N, MOD) - 1) % MOD)
        return

    primes = sorted(k_factors.keys())
    m = len(primes)
    required = [k_factors[p] for p in primes]

    elements = []
    for a in A:
        exponents = []
        a_remaining = a
        for p in primes:
            e = 0
            while a_remaining % p == 0:
                e += 1
                a_remaining = a_remaining // p
            exponents.append(e)
        mask = 0
        for i in range(m):
            if exponents[i] <= required[i] - 1:
                mask |= (1 << i)
        non_zero_mask = 0
        for i in range(m):
            if exponents[i] > 0:
                non_zero_mask |= (1 << i)
        elements.append((exponents, mask, non_zero_mask))

    sum_terms = 0

    for mask_T in range(1, 1 << m):
        T_primes = [i for i in range(m) if (mask_T & (1 << i))]
        B = []
        for elem in elements:
            exponents_e, mask_e, non_zero_e = elem
            if (mask_e & mask_T) == mask_T:
                B.append(elem)
        all_zero = True
        for elem in B:
            if (elem[2] & mask_T) != 0:
                all_zero = False
                break
        if all_zero:
            s_t = (pow(2, len(B), MOD) - 1) % MOD
        else:
            required_e = [required[i] for i in T_primes]
            exponents_list = []
            for elem in B:
                exps = elem[0]
                exps_T = [exps[i] for i in T_primes]
                exponents_list.append(exps_T)
            dp = defaultdict(int)
            initial_state = tuple([0] * len(T_primes))
            dp[initial_state] = 1
            for exps in exponents_list:
                new_dp = defaultdict(int)
                for state, cnt in dp.items():
                    new_state = list(state)
                    valid = True
                    for i in range(len(new_state)):
                        new_state[i] += exps[i]
                        if new_state[i] >= required_e[i]:
                            valid = False
                            break
                    if valid:
                        new_state_tuple = tuple(new_state)
                        new_dp[new_state_tuple] = (new_dp[new_state_tuple] + cnt) % MOD
                    new_dp[state] = (new_dp[state] + cnt) % MOD
                dp = new_dp
            total = (sum(dp.values()) - 1) % MOD
            s_t = total
        k = len(T_primes)
        term = (pow(-1, k, MOD) * s_t) % MOD
        sum_terms = (sum_terms + term) % MOD

    total_subsets = (pow(2, N, MOD) - 1) % MOD
    answer = (total_subsets + sum_terms) % MOD
    print(answer)

if __name__ == '__main__':
    main()
0