結果

問題 No.2313 Product of Subsequence (hard)
ユーザー gew1fw
提出日時 2025-06-12 20:20:27
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,021 bytes
コンパイル時間 274 ms
コンパイル使用メモリ 81,784 KB
実行使用メモリ 285,252 KB
最終ジャッジ日時 2025-06-12 20:20:56
合計ジャッジ時間 9,171 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 9 TLE * 2 -- * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from math import gcd
from collections import defaultdict

MOD = 998244353

def factorize(n):
    factors = {}
    while n % 2 == 0:
        factors[2] = factors.get(2, 0) + 1
        n //= 2
    i = 3
    while i * i <= n:
        while n % i == 0:
            factors[i] = factors.get(i, 0) + 1
            n //= i
        i += 2
    if n > 1:
        factors[n] = 1
    return factors

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr])
    ptr += 1
    K = int(input[ptr])
    ptr += 1
    A = list(map(int, input[ptr:ptr+N]))
    ptr += N

    if K == 1:
        print(pow(2, N, MOD) - 1)
        return

    factors = factorize(K)
    primes = list(factors.keys())
    m = len(primes)
    e = {p: factors[p] for p in primes}

    prime_indices = {p: i for i, p in enumerate(primes)}

    def get_exponents(x):
        res = []
        for p in primes:
            cnt = 0
            while x % p == 0:
                cnt += 1
                x //= p
            res.append(cnt)
        return res

    a_exponents = []
    for x in A:
        a_exponents.append(get_exponents(x))

    total = pow(2, N, MOD) - 1

    from itertools import combinations, chain

    def powerset(s):
        return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

    prime_indices_list = list(range(m))

    res = 0

    for mask in range(1, 1 << m):
        bits = bin(mask).count('1')
        S = [primes[i] for i in range(m) if (mask >> i) & 1]
        S_e = [e[p] for p in S]

        filtered = []
        for ex in a_exponents:
            include = True
            for i, p in enumerate(S):
                if ex[prime_indices[p]] >= S_e[i]:
                    include = False
                    break
            if include:
                filtered.append([ex[prime_indices[p]] for p in S])

        if not filtered:
            h = 0
        else:
            m_S = len(S)
            max_sum = [s for s in S_e]
            state = defaultdict(int)
            state[tuple([0]*m_S)] = 1

            for exponents in filtered:
                exponents = exponents
                new_state = defaultdict(int)
                for current_sum, cnt in state.items():
                    new_sum = list(current_sum)
                    valid = True
                    for i in range(m_S):
                        new_sum[i] += exponents[i]
                        if new_sum[i] >= max_sum[i]:
                            valid = False
                            break
                    if valid:
                        new_sum_tuple = tuple(new_sum)
                        new_state[new_sum_tuple] = (new_state[new_sum_tuple] + cnt) % MOD
                for key, val in new_state.items():
                    state[key] = (state[key] + val) % MOD

            h = (sum(state.values()) - 1) % MOD

        sign = (-1) ** (bits)
        res = (res + sign * h) % MOD

    res = (res + total) % MOD
    print(res % MOD)

if __name__ == '__main__':
    main()
0