結果

問題 No.2313 Product of Subsequence (hard)
ユーザー gew1fw
提出日時 2025-06-12 13:20:46
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,396 bytes
コンパイル時間 437 ms
コンパイル使用メモリ 82,768 KB
実行使用メモリ 278,128 KB
最終ジャッジ日時 2025-06-12 13:22:53
合計ジャッジ時間 9,978 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 10 TLE * 1 -- * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def main():
    import sys
    input = sys.stdin.read().split()
    ptr = 0
    N, K = int(input[ptr]), int(input[ptr+1])
    ptr += 2
    A = list(map(int, input[ptr:ptr+N]))
    ptr += N

    if K == 1:
        print((pow(2, N, MOD) - 1) % MOD)
        return

    # Factorize K
    factors = factorize(K)
    primes = list(factors.keys())
    exponents = [factors[p] for p in primes]
    m = len(primes)

    # Split into relevant and irrelevant elements
    relevant = []
    irrelevant = []
    for a in A:
        exps = []
        for p in primes:
            cnt = 0
            x = a
            while x % p == 0 and x != 0:
                cnt += 1
                x = x // p
            exps.append(cnt)
        if all(e == 0 for e in exps):
            irrelevant.append(a)
        else:
            relevant.append(exps)

    R = len(relevant)
    M = len(irrelevant)

    total_invalid = 0

    # Iterate over all non-empty subsets of primes
    for mask in range(1, 1 << m):
        S = [i for i in range(m) if (mask >> i) & 1]
        S_e = [exponents[i] for i in S]
        active = []
        for exps in relevant:
            valid = True
            has_positive = False
            for i in S:
                e = exps[i]
                if e >= exponents[i]:
                    valid = False
                    break
                if e > 0:
                    has_positive = True
            if valid and has_positive:
                s_exps = [exps[i] for i in S]
                active.append(s_exps)

        Q = 0
        for exps in relevant:
            valid = True
            has_positive = False
            for i in S:
                e = exps[i]
                if e >= exponents[i]:
                    valid = False
                    break
                if e > 0:
                    has_positive = True
            if valid and has_positive:
                continue
            all_zero = True
            for i in S:
                if exps[i] != 0:
                    all_zero = False
                    break
            if all_zero:
                Q += 1

        # Compute X using DP
        dp = {}
        initial_state = tuple([0] * len(S))
        dp[initial_state] = 1
        for exps in active:
            new_dp = dp.copy()
            for state, cnt in dp.items():
                new_state = list(state)
                valid = True
                for i in range(len(S)):
                    new_state[i] += exps[i]
                    if new_state[i] >= S_e[i]:
                        valid = False
                        break
                if valid:
                    new_state_tuple = tuple(new_state)
                    new_dp[new_state_tuple] = (new_dp.get(new_state_tuple, 0) + cnt) % MOD
            dp = new_dp

        X = sum(dp.values()) % MOD
        f_S = (X * pow(2, Q, MOD)) % MOD
        k = len(S)
        sign = (-1) ** (k + 1)
        total_invalid = (total_invalid + sign * f_S) % MOD

    valid_relevant = (pow(2, R, MOD) - total_invalid) % MOD
    ans = (valid_relevant * pow(2, M, MOD)) % MOD
    print(ans)

def factorize(n):
    factors = {}
    i = 2
    while i * i <= n:
        while n % i == 0:
            factors[i] = factors.get(i, 0) + 1
            n = n // i
        i += 1
    if n > 1:
        factors[n] = 1
    return factors

if __name__ == '__main__':
    main()
0