結果
問題 |
No.2313 Product of Subsequence (hard)
|
ユーザー |
![]() |
提出日時 | 2025-06-12 13:20:46 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,396 bytes |
コンパイル時間 | 437 ms |
コンパイル使用メモリ | 82,768 KB |
実行使用メモリ | 278,128 KB |
最終ジャッジ日時 | 2025-06-12 13:22:53 |
合計ジャッジ時間 | 9,978 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 10 TLE * 1 -- * 16 |
ソースコード
MOD = 998244353 def main(): import sys input = sys.stdin.read().split() ptr = 0 N, K = int(input[ptr]), int(input[ptr+1]) ptr += 2 A = list(map(int, input[ptr:ptr+N])) ptr += N if K == 1: print((pow(2, N, MOD) - 1) % MOD) return # Factorize K factors = factorize(K) primes = list(factors.keys()) exponents = [factors[p] for p in primes] m = len(primes) # Split into relevant and irrelevant elements relevant = [] irrelevant = [] for a in A: exps = [] for p in primes: cnt = 0 x = a while x % p == 0 and x != 0: cnt += 1 x = x // p exps.append(cnt) if all(e == 0 for e in exps): irrelevant.append(a) else: relevant.append(exps) R = len(relevant) M = len(irrelevant) total_invalid = 0 # Iterate over all non-empty subsets of primes for mask in range(1, 1 << m): S = [i for i in range(m) if (mask >> i) & 1] S_e = [exponents[i] for i in S] active = [] for exps in relevant: valid = True has_positive = False for i in S: e = exps[i] if e >= exponents[i]: valid = False break if e > 0: has_positive = True if valid and has_positive: s_exps = [exps[i] for i in S] active.append(s_exps) Q = 0 for exps in relevant: valid = True has_positive = False for i in S: e = exps[i] if e >= exponents[i]: valid = False break if e > 0: has_positive = True if valid and has_positive: continue all_zero = True for i in S: if exps[i] != 0: all_zero = False break if all_zero: Q += 1 # Compute X using DP dp = {} initial_state = tuple([0] * len(S)) dp[initial_state] = 1 for exps in active: new_dp = dp.copy() for state, cnt in dp.items(): new_state = list(state) valid = True for i in range(len(S)): new_state[i] += exps[i] if new_state[i] >= S_e[i]: valid = False break if valid: new_state_tuple = tuple(new_state) new_dp[new_state_tuple] = (new_dp.get(new_state_tuple, 0) + cnt) % MOD dp = new_dp X = sum(dp.values()) % MOD f_S = (X * pow(2, Q, MOD)) % MOD k = len(S) sign = (-1) ** (k + 1) total_invalid = (total_invalid + sign * f_S) % MOD valid_relevant = (pow(2, R, MOD) - total_invalid) % MOD ans = (valid_relevant * pow(2, M, MOD)) % MOD print(ans) def factorize(n): factors = {} i = 2 while i * i <= n: while n % i == 0: factors[i] = factors.get(i, 0) + 1 n = n // i i += 1 if n > 1: factors[n] = 1 return factors if __name__ == '__main__': main()