結果
問題 | No.2313 Product of Subsequence (hard) |
ユーザー |
![]() |
提出日時 | 2025-06-12 18:37:27 |
言語 | PyPy3 (7.3.15) |
結果 | TLE |
実行時間 | - |
コード長 | 2,691 bytes |
コンパイル時間 | 255 ms |
コンパイル使用メモリ | 82,432 KB |
実行使用メモリ | 80,672 KB |
最終ジャッジ日時 | 2025-06-12 18:37:49 |
合計ジャッジ時間 | 7,860 ms |
ジャッジサーバーID (参考情報) | judge1 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 5 TLE * 1 -- * 21 |
ソースコード
import sys
from sys import stdin
from collections import Counter, defaultdict

MOD = 998244353


def factorize(k):
    """Return the prime factorisation of ``k`` as a ``{prime: exponent}`` dict.

    Plain trial division up to sqrt(k); ``factorize(1)`` returns ``{}``.
    """
    factors = {}
    i = 2
    while i * i <= k:
        while k % i == 0:
            factors[i] = factors.get(i, 0) + 1
            k //= i
        i += 1
    if k > 1:
        # The remaining cofactor has no divisor <= its square root, so it is prime.
        factors[k] = 1
    return factors


def _capped_exponents(a, primes, required):
    """Return ``tuple(min(v_p(a), e_p) for p, e_p in zip(primes, required))``.

    Capping at the exponent K needs keeps the number of distinct vectors at
    most the number of divisors of K.  It also makes the loop terminate for
    ``a == 0``, which then counts as divisible by every prime power of K.
    """
    vec = []
    for p, need in zip(primes, required):
        e = 0
        while e < need and a % p == 0:
            a //= p
            e += 1
        vec.append(e)
    return tuple(vec)


def _count_below(groups, required, inverses):
    """Count subsequences (the empty one included) that stay below ``required``.

    ``groups`` maps an exponent signature, whose every coordinate is already
    strictly below ``required``, to the number of input elements with that
    signature.  The result is the number of index subsets whose summed
    signature stays strictly below ``required`` in every coordinate, mod MOD.
    ``inverses[j]`` must be the modular inverse of ``j`` for
    ``1 <= j < max(required)``.
    """
    zero = (0,) * len(required)
    # Elements with an all-zero signature never change the state: factor 2^c.
    free = groups.get(zero, 0)
    dp = {zero: 1}
    for sig, cnt in groups.items():
        if sig == zero:
            continue
        # At most this many copies of `sig` fit even when starting from zero.
        limit = min((need - 1) // v for v, need in zip(sig, required) if v)
        limit = min(limit, cnt)
        # binom[j] = C(cnt, j) mod MOD: ways to pick j of the cnt equal elements.
        binom = [1] * (limit + 1)
        for j in range(1, limit + 1):
            binom[j] = binom[j - 1] * (cnt - j + 1) % MOD * inverses[j] % MOD
        new_dp = defaultdict(int)
        for state, ways in dp.items():
            for j in range(limit + 1):
                target = tuple(s + j * v for s, v in zip(state, sig))
                # Coordinates only grow with j, so the first overflow ends the scan.
                if any(t >= need for t, need in zip(target, required)):
                    break
                new_dp[target] = (new_dp[target] + ways * binom[j]) % MOD
        dp = new_dp
    return sum(dp.values()) % MOD * pow(2, free, MOD) % MOD


def solve(K, A):
    """Return the number of non-empty subsequences of ``A`` whose product is a
    multiple of ``K``, modulo MOD.

    Complement + inclusion-exclusion: a subsequence is "bad" if some prime p
    of K has total exponent below e_p.  For each non-empty prime subset S we
    count the subsequences that are short on *every* prime of S.  Elements are
    grouped by capped exponent vector, so the cost does not grow with len(A)
    beyond the initial pass.
    """
    total = (pow(2, len(A), MOD) - 1) % MOD
    if K == 1:
        return total
    factors = factorize(K)
    if not factors:
        # Only reachable for K < 1; kept identical to the previous behaviour.
        return 0
    primes = list(factors)
    required = [factors[p] for p in primes]
    m = len(primes)

    vectors = Counter(_capped_exponents(a, primes, required) for a in A)
    max_need = max(required)
    inverses = [0] + [pow(j, MOD - 2, MOD) for j in range(1, max_need + 1)]

    bad = 0
    for mask in range(1, 1 << m):
        idx = [i for i in range(m) if mask >> i & 1]
        req = [required[i] for i in idx]
        # Project onto S; elements already meeting some e_p (p in S) can never
        # appear in a subsequence that is short on every prime of S.
        groups = Counter()
        for vec, cnt in vectors.items():
            sig = tuple(vec[i] for i in idx)
            if all(s < r for s, r in zip(sig, req)):
                groups[sig] += cnt
        count = (_count_below(groups, req, inverses) - 1) % MOD  # drop empty one
        if len(idx) % 2:
            bad += count
        else:
            bad -= count
    return (total - bad) % MOD


def main():
    """Read ``N K`` followed by ``A_1 .. A_N`` from stdin and print the answer."""
    data = stdin.buffer.read().split()
    n, k = int(data[0]), int(data[1])
    a = list(map(int, data[2:2 + n]))
    print(solve(k, a))


if __name__ == '__main__':
    main()