結果

問題 No.2313 Product of Subsequence (hard)
ユーザー gew1fw
提出日時 2025-06-12 18:39:49
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 2,691 bytes
コンパイル時間 258 ms
コンパイル使用メモリ 82,544 KB
実行使用メモリ 289,264 KB
最終ジャッジ日時 2025-06-12 18:40:01
合計ジャッジ時間 7,336 ms
ジャッジサーバーID (参考情報) judge3 / judge2
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 5 TLE * 1 -- * 21
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
from collections import defaultdict

# Every count printed by main() is reduced modulo this prime (119 * 2**23 + 1).
MOD = 998_244_353

def factorize(k):
    """Return the prime factorisation of *k* as a ``{prime: exponent}`` dict.

    Uses plain trial division by every integer ``d`` with ``d * d <= k``.
    Primes are inserted in ascending order. Whatever remains above 1
    afterwards is a single prime factor. For ``k <= 1`` the result is empty.
    """
    result = {}
    d = 2
    while d * d <= k:
        count = 0
        while k % d == 0:
            k //= d
            count += 1
        if count:
            result[d] = count
        d += 1
    if k > 1:
        # Leftover cofactor has no divisor <= its square root, so it is prime.
        result[k] = 1
    return result

def main():
    """Read ``N K A_1..A_N`` from stdin and print the answer modulo MOD.

    The answer is the number of non-empty subsequences of A whose product
    is a multiple of K.

    The complement is counted. A subsequence is *bad* when, for at least
    one prime p of K, the exponent of p in its product is smaller than in
    K. By inclusion-exclusion over the non-empty sets S of primes of K::

        bad = sum_S (-1)^(|S|+1) * (#non-empty subsequences deficient in
                                     every prime of S)

    Elements are grouped once by their exponent vector, with each exponent
    capped at K's exponent. Each per-S DP then works over distinct vectors
    with binomial multiplicities instead of over all N elements. The
    previous element-by-element DP, repeated for every S, exceeded the
    time limit.
    """
    data = sys.stdin.read().split()
    N, K = int(data[0]), int(data[1])
    A = [int(tok) for tok in data[2:2 + N]]

    if K == 1:
        # Every product is a multiple of 1.
        print((pow(2, N, MOD) - 1) % MOD)
        return

    factors = factorize(K)
    primes = list(factors)
    required = [factors[p] for p in primes]
    m = len(primes)
    if m == 0:
        # factorize() returns {} only for K <= 1; K == 1 was handled above.
        print(0)
        return

    groups = _group_by_capped_exponents(A, primes, required)

    total = (pow(2, N, MOD) - 1) % MOD
    bad = 0
    for mask in range(1, 1 << m):
        chosen = [i for i in range(m) if mask >> i & 1]
        # The count includes the empty subsequence, hence the "- 1" below.
        deficient = _count_deficient(groups, chosen, required)
        sign = 1 if len(chosen) % 2 == 1 else -1
        bad = (bad + sign * (deficient - 1)) % MOD

    print((total - bad) % MOD)


def _group_by_capped_exponents(values, primes, required):
    """Count *values* by exponent vector over *primes*, capped at *required*.

    An exponent at or above ``required[i]`` already satisfies prime i on its
    own, so its exact size does not matter. Capping it also bounds the
    division loop for ``a == 0``, which would otherwise spin forever
    because ``0 % p == 0`` always holds.
    """
    groups = defaultdict(int)
    for a in values:
        vec = []
        for p, r in zip(primes, required):
            e = 0
            while e < r and a % p == 0:
                a //= p
                e += 1
            vec.append(e)
        groups[tuple(vec)] += 1
    return groups


def _count_deficient(groups, chosen, required):
    """Count subsequences that are deficient in every chosen prime.

    The empty subsequence is included. A subsequence is deficient in prime
    i when the sum of its elements' exponents of that prime is below
    ``required[i]``. *chosen* lists prime indices; the result is mod MOD.
    """
    req = [required[i] for i in chosen]
    zero = (0,) * len(chosen)

    # Project onto the chosen primes. An element that alone reaches a
    # required exponent can never appear in a deficient subsequence.
    projected = defaultdict(int)
    for vec, cnt in groups.items():
        sub = tuple(vec[i] for i in chosen)
        if all(e < r for e, r in zip(sub, req)):
            projected[sub] += cnt

    # All-zero elements never change the state: each may be taken or not.
    free = projected.pop(zero, 0)

    dp = {zero: 1}
    for vec, cnt in projected.items():
        dp = _add_group(dp, vec, cnt, req)
    return sum(dp.values()) % MOD * pow(2, free, MOD) % MOD


def _add_group(dp, vec, cnt, req):
    """Extend *dp* by taking any number of *cnt* elements with exponents *vec*.

    *dp* maps an exponent-sum tuple (every entry below *req*) to a count.
    Taking j of the cnt identical elements adds ``j * vec`` and can be done
    in C(cnt, j) ways. Any state reaching *req* in some coordinate is
    dropped. *vec* is non-zero, so j never exceeds ``max(req) - 1``.
    """
    jmax = min(cnt, min((r - 1) // v for v, r in zip(vec, req) if v))
    binom = [1]
    c = 1
    for j in range(1, jmax + 1):
        c = c * (cnt - j + 1) // j  # exact integer C(cnt, j)
        binom.append(c % MOD)

    new_dp = defaultdict(int)
    for state, ways in dp.items():
        cur = state
        for j in range(jmax + 1):
            if any(s >= r for s, r in zip(cur, req)):
                break
            new_dp[cur] = (new_dp[cur] + ways * binom[j]) % MOD
            cur = tuple(s + v for s, v in zip(cur, vec))
    return new_dp

# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
0