結果

問題 No.2313 Product of Subsequence (hard)
ユーザー gew1fw
提出日時 2025-06-12 13:22:16
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 3,602 bytes
コンパイル時間 270 ms
コンパイル使用メモリ 82,688 KB
実行使用メモリ 154,648 KB
最終ジャッジ日時 2025-06-12 13:25:09
合計ジャッジ時間 11,152 ms
ジャッジサーバーID (参考情報): judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 10 TLE * 1 -- * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import Counter, defaultdict
from math import gcd

MOD = 998244353

def factorize(k):
    """Return the prime factorization of ``k`` as a ``{prime: exponent}`` dict.

    Uses trial division by every integer ``divisor >= 2`` while
    ``divisor * divisor`` does not exceed what is left of ``k``. Any
    remainder above 1 is a prime and is recorded with exponent 1.

    Keys are inserted in increasing order. Inputs ``k <= 1`` give an
    empty dict.
    """
    primes_found = {}
    divisor = 2
    while divisor * divisor <= k:
        power = 0
        quotient, remainder = divmod(k, divisor)
        # Strip every copy of this divisor before moving on, so later
        # divisors that reach this point can only be primes.
        while remainder == 0:
            k = quotient
            power += 1
            quotient, remainder = divmod(k, divisor)
        if power:
            primes_found[divisor] = power
        divisor += 1
    if k > 1:
        primes_found[k] = 1
    return primes_found

def _count_deficient_subsets(groups, coords, limits, inv):
    """Count index subsets whose exponent sums stay strictly below ``limits``.

    ``groups`` is a list of ``(exponent_vector, count)`` pairs: ``count``
    elements share that exponent vector. Only the positions listed in
    ``coords`` are constrained; the sum at position ``coords[t]`` must be
    less than ``limits[t]``.

    ``inv[j]`` must hold the inverse of ``j`` modulo MOD for every
    ``j < max(limits)``.

    Returns the count modulo MOD. The empty subset is included.
    """
    free = 0  # elements with exponent 0 on every constrained position
    merged = Counter()
    for vec, cnt in groups:
        proj = tuple(vec[i] for i in coords)
        if not any(proj):
            free += cnt
        elif all(e < lim for e, lim in zip(proj, limits)):
            merged[proj] += cnt
        # Otherwise one such element alone already reaches a limit, so it
        # can never be part of a counted subset.

    # Mixed-radix encoding of a state x: index = sum(x[t] * strides[t]).
    strides = []
    size = 1
    for lim in limits:
        strides.append(size)
        size *= lim
    states = [[s // st % lim for st, lim in zip(strides, limits)]
              for s in range(size)]

    dp = [0] * size
    dp[0] = 1
    for vec, cnt in merged.items():
        # No coordinate overflows its radix, so adding j copies of vec is
        # just adding j * step to the encoded index.
        step = sum(e * st for e, st in zip(vec, strides))
        # Most copies of this vector that fit when starting from state 0.
        jcap = min(cnt, min((lim - 1) // e for e, lim in zip(vec, limits) if e))
        binom = [1] * (jcap + 1)  # binom[j] == C(cnt, j) mod MOD
        for j in range(1, jcap + 1):
            binom[j] = binom[j - 1] * (cnt - j + 1) % MOD * inv[j] % MOD
        new_dp = dp[:]  # j == 0: none of these elements is chosen
        for s in range(size):
            ways = dp[s]
            if not ways:
                continue
            reach = jcap
            for e, x, lim in zip(vec, states[s], limits):
                if e:
                    r = (lim - 1 - x) // e
                    if r < reach:
                        reach = r
            t = s
            for j in range(1, reach + 1):
                t += step
                new_dp[t] = (new_dp[t] + ways * binom[j]) % MOD
        dp = new_dp
    return sum(dp) % MOD * pow(2, free, MOD) % MOD


def _count_divisible_subsequences(A, K):
    """Count non-empty subsequences of ``A`` whose product is divisible by ``K``.

    The result is returned modulo MOD.

    Method: inclusion-exclusion over the prime factors of ``K``. For every
    subset S of those primes, count the subsequences in which each p in S
    stays below its exponent in ``K``, and add that count with sign
    (-1)**|S|. The empty subsequence is counted once for every S, so it
    cancels whenever ``K`` has at least one prime factor.
    """
    n = len(A)
    if K == 1:
        # Every product is divisible by 1; only the empty choice is excluded.
        return (pow(2, n, MOD) - 1) % MOD

    factors = factorize(K)
    if not factors:
        # K <= 0 is outside the problem's domain; report 0 as before.
        return 0
    primes = list(factors)
    required = [factors[p] for p in primes]
    m = len(primes)

    # Divisibility by K depends only on gcd(a, K). That gcd takes at most
    # d(K) distinct values, so elements are grouped by it. Since
    # gcd(0, K) == K, zero elements need no special case.
    groups = []
    for g, cnt in Counter(gcd(a, K) for a in A).items():
        vec = []
        for p in primes:
            e = 0
            while g % p == 0:
                g //= p
                e += 1
            vec.append(e)
        groups.append((vec, cnt))

    max_req = max(required)
    inv = [0] + [pow(j, MOD - 2, MOD) for j in range(1, max_req + 1)]

    total = 0
    for mask in range(1 << m):
        coords = [i for i in range(m) if mask >> i & 1]
        limits = [required[i] for i in coords]
        count = _count_deficient_subsets(groups, coords, limits, inv)
        total += -count if len(coords) % 2 else count
    return total % MOD


def main():
    """Solve one test case from stdin.

    Input is ``N K`` followed by ``A_1 .. A_N``. Prints the number of
    non-empty subsequences whose product is divisible by ``K``, modulo MOD.
    """
    data = sys.stdin.read().split()
    n, k = int(data[0]), int(data[1])
    a = list(map(int, data[2:2 + n]))
    print(_count_divisible_subsequences(a, k))

# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()