結果
問題 |
No.2313 Product of Subsequence (hard)
|
ユーザー |
![]() |
提出日時 | 2025-06-12 20:20:27 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,021 bytes |
コンパイル時間 | 274 ms |
コンパイル使用メモリ | 81,784 KB |
実行使用メモリ | 285,252 KB |
最終ジャッジ日時 | 2025-06-12 20:20:56 |
合計ジャッジ時間 | 9,171 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 9 TLE * 2 -- * 16 |
ソースコード
import sys from math import gcd from collections import defaultdict MOD = 998244353 def factorize(n): factors = {} while n % 2 == 0: factors[2] = factors.get(2, 0) + 1 n //= 2 i = 3 while i * i <= n: while n % i == 0: factors[i] = factors.get(i, 0) + 1 n //= i i += 2 if n > 1: factors[n] = 1 return factors def main(): input = sys.stdin.read().split() ptr = 0 N = int(input[ptr]) ptr += 1 K = int(input[ptr]) ptr += 1 A = list(map(int, input[ptr:ptr+N])) ptr += N if K == 1: print(pow(2, N, MOD) - 1) return factors = factorize(K) primes = list(factors.keys()) m = len(primes) e = {p: factors[p] for p in primes} prime_indices = {p: i for i, p in enumerate(primes)} def get_exponents(x): res = [] for p in primes: cnt = 0 while x % p == 0: cnt += 1 x //= p res.append(cnt) return res a_exponents = [] for x in A: a_exponents.append(get_exponents(x)) total = pow(2, N, MOD) - 1 from itertools import combinations, chain def powerset(s): return chain.from_iterable(combinations(s, r) for r in range(len(s)+1)) prime_indices_list = list(range(m)) res = 0 for mask in range(1, 1 << m): bits = bin(mask).count('1') S = [primes[i] for i in range(m) if (mask >> i) & 1] S_e = [e[p] for p in S] filtered = [] for ex in a_exponents: include = True for i, p in enumerate(S): if ex[prime_indices[p]] >= S_e[i]: include = False break if include: filtered.append([ex[prime_indices[p]] for p in S]) if not filtered: h = 0 else: m_S = len(S) max_sum = [s for s in S_e] state = defaultdict(int) state[tuple([0]*m_S)] = 1 for exponents in filtered: exponents = exponents new_state = defaultdict(int) for current_sum, cnt in state.items(): new_sum = list(current_sum) valid = True for i in range(m_S): new_sum[i] += exponents[i] if new_sum[i] >= max_sum[i]: valid = False break if valid: new_sum_tuple = tuple(new_sum) new_state[new_sum_tuple] = (new_state[new_sum_tuple] + cnt) % MOD for key, val in new_state.items(): state[key] = (state[key] + val) % MOD h = (sum(state.values()) - 1) % MOD sign = (-1) ** (bits) res = (res + sign * h) % MOD res = (res + total) % MOD print(res % MOD) if __name__ == '__main__': main()