結果
問題 |
No.2313 Product of Subsequence (hard)
|
ユーザー |
![]() |
提出日時 | 2025-04-15 23:42:30 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,285 bytes |
コンパイル時間 | 132 ms |
コンパイル使用メモリ | 82,644 KB |
実行使用メモリ | 262,548 KB |
最終ジャッジ日時 | 2025-04-15 23:44:50 |
合計ジャッジ時間 | 12,550 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 10 TLE * 1 -- * 16 |
ソースコード
MOD = 998244353 def main(): import sys from sys import stdin from collections import defaultdict N, K = map(int, stdin.readline().split()) A = list(map(int, stdin.readline().split())) if K == 1: print((pow(2, N, MOD) - 1) % MOD) return def factorize(k): factors = {} i = 2 while i * i <= k: while k % i == 0: factors[i] = factors.get(i, 0) + 1 k = k // i i += 1 if k > 1: factors[k] = 1 return factors k_factors = factorize(K) if not k_factors: print((pow(2, N, MOD) - 1) % MOD) return primes = sorted(k_factors.keys()) m = len(primes) required = [k_factors[p] for p in primes] elements = [] for a in A: exponents = [] a_remaining = a for p in primes: e = 0 while a_remaining % p == 0: e += 1 a_remaining = a_remaining // p exponents.append(e) mask = 0 for i in range(m): if exponents[i] <= required[i] - 1: mask |= (1 << i) non_zero_mask = 0 for i in range(m): if exponents[i] > 0: non_zero_mask |= (1 << i) elements.append((exponents, mask, non_zero_mask)) sum_terms = 0 for mask_T in range(1, 1 << m): T_primes = [i for i in range(m) if (mask_T & (1 << i))] B = [] for elem in elements: exponents_e, mask_e, non_zero_e = elem if (mask_e & mask_T) == mask_T: B.append(elem) all_zero = True for elem in B: if (elem[2] & mask_T) != 0: all_zero = False break if all_zero: s_t = (pow(2, len(B), MOD) - 1) % MOD else: required_e = [required[i] for i in T_primes] exponents_list = [] for elem in B: exps = elem[0] exps_T = [exps[i] for i in T_primes] exponents_list.append(exps_T) dp = defaultdict(int) initial_state = tuple([0] * len(T_primes)) dp[initial_state] = 1 for exps in exponents_list: new_dp = defaultdict(int) for state, cnt in dp.items(): new_state = list(state) valid = True for i in range(len(new_state)): new_state[i] += exps[i] if new_state[i] >= required_e[i]: valid = False break if valid: new_state_tuple = tuple(new_state) new_dp[new_state_tuple] = (new_dp[new_state_tuple] + cnt) % MOD new_dp[state] = (new_dp[state] + cnt) % MOD dp = new_dp total = (sum(dp.values()) - 1) % MOD s_t = total k = len(T_primes) term = (pow(-1, k, MOD) * s_t) % MOD sum_terms = (sum_terms + term) % MOD total_subsets = (pow(2, N, MOD) - 1) % MOD answer = (total_subsets + sum_terms) % MOD print(answer) if __name__ == '__main__': main()