結果
問題 |
No.2313 Product of Subsequence (hard)
|
ユーザー |
![]() |
提出日時 | 2025-06-12 13:22:16 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,602 bytes |
コンパイル時間 | 270 ms |
コンパイル使用メモリ | 82,688 KB |
実行使用メモリ | 154,648 KB |
最終ジャッジ日時 | 2025-06-12 13:25:09 |
合計ジャッジ時間 | 11,152 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 10 TLE * 1 -- * 16 |
ソースコード
import sys from collections import defaultdict MOD = 998244353 def factorize(k): factors = {} i = 2 while i * i <= k: if k % i == 0: cnt = 0 while k % i == 0: cnt += 1 k //= i factors[i] = cnt i += 1 if k > 1: factors[k] = 1 return factors def main(): input = sys.stdin.read().split() ptr = 0 N = int(input[ptr]) ptr += 1 K = int(input[ptr]) ptr += 1 A = list(map(int, input[ptr:ptr+N])) ptr += N if K == 1: print((pow(2, N, MOD) - 1) % MOD) return factors = factorize(K) primes = list(factors.keys()) required = [factors[p] for p in primes] m = len(primes) # Precompute exponents for each element relevant = [] neutral = [] for a in A: ex = [] valid = False for p in primes: cnt = 0 x = a while x % p == 0: cnt += 1 x //= p ex.append(cnt) if cnt > 0: valid = True if valid: relevant.append(ex) else: neutral.append(a) M = len(neutral) R = len(relevant) if not relevant: print(0) return # Precompute powers of 2 max_pow = max(len(relevant), len(neutral)) + 1 pow2 = [1] * (max_pow) for i in range(1, max_pow): pow2[i] = (pow2[i-1] * 2) % MOD # Inclusion-exclusion over all subsets of primes total_valid = 0 from itertools import combinations for mask in range(1 << m): S = [i for i in range(m) if (mask >> i) & 1] sign = (-1) ** len(S) if not S: term = pow2[R] # all subsets of relevant elements total_valid = (total_valid + sign * term) % MOD continue # Compute allowed sums for each prime in S allowed = [required[i] - 1 for i in S] primes_in_S = [primes[i] for i in S] # Compute C: number of relevant elements with zero exponents for all primes in S C = 0 T_prime = [] for ex in relevant: ex_S = [ex[i] for i in S] if all(e == 0 for e in ex_S): C += 1 else: if all(e <= allowed[i] for i, e in enumerate(ex_S)) and any(e > 0 for e in ex_S): T_prime.append(ex_S) # Compute DP for T_prime m_prime = len(S) dp = defaultdict(int) initial = tuple([0]*m_prime) dp[initial] = 1 for ex_S in T_prime: new_dp = defaultdict(int) for state, cnt in dp.items(): new_state = list(state) valid = True for i in range(m_prime): new_state[i] += ex_S[i] if new_state[i] > allowed[i]: valid = False break if valid: new_state_tuple = tuple(new_state) new_dp[new_state_tuple] = (new_dp[new_state_tuple] + cnt) % MOD new_dp[tuple(state)] = (new_dp.get(tuple(state), 0) + cnt) % MOD dp = new_dp valid_subsets = sum(dp.values()) % MOD term = (valid_subsets * pow2[C]) % MOD total_valid = (total_valid + sign * term) % MOD # Ensure total_valid is non-negative total_valid = (total_valid % MOD + MOD) % MOD # Multiply by 2^M (neutral elements) result = (total_valid * pow2[M]) % MOD print(result) if __name__ == "__main__": main()