結果
問題 |
No.2313 Product of Subsequence (hard)
|
ユーザー |
![]() |
提出日時 | 2025-04-15 23:46:07 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,389 bytes |
コンパイル時間 | 158 ms |
コンパイル使用メモリ | 81,524 KB |
実行使用メモリ | 186,936 KB |
最終ジャッジ日時 | 2025-04-15 23:48:51 |
合計ジャッジ時間 | 12,266 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 10 TLE * 1 -- * 16 |
ソースコード
MOD = 998244353 def factorize(k): factors = {} i = 2 while i * i <= k: while k % i == 0: factors[i] = factors.get(i, 0) + 1 k //= i i += 1 if k > 1: factors[k] = 1 return factors def main(): import sys input = sys.stdin.read().split() ptr = 0 N, K = int(input[ptr]), int(input[ptr+1]) ptr += 2 A = list(map(int, input[ptr:ptr+N])) ptr += N if K == 1: print((pow(2, N, MOD) - 1) % MOD) return factors = factorize(K) primes = list(factors.keys()) m = len(primes) e = [factors[p] for p in primes] element_exponents = [] for num in A: ex = [] for p in primes: cnt = 0 tmp = num while tmp % p == 0: cnt += 1 tmp //= p ex.append(cnt) element_exponents.append(ex) inclusion_exclusion = 0 for mask in range(1, 1 << m): subset_indices = [] subset_e = [] for i in range(m): if (mask >> i) & 1: subset_indices.append(i) subset_e.append(e[i]) k = len(subset_indices) allowed = [] for idx in range(N): valid = True for i in subset_indices: if element_exponents[idx][i] >= e[i]: valid = False break if valid: allowed.append(idx) if not allowed: count = 0 else: state_size = 1 for ee in subset_e: state_size *= ee if state_size > 10**6: continue multipliers = [1] * len(subset_e) for i in range(1, len(subset_e)): multipliers[i] = multipliers[i-1] * subset_e[i-1] dp = [0] * state_size dp[0] = 1 for idx in allowed: ex = [element_exponents[idx][i] for i in subset_indices] new_dp = dp.copy() for state in range(state_size): if dp[state] == 0: continue current = [] rem = state for i in range(len(subset_e)): current.append(rem % subset_e[i]) rem = rem // subset_e[i] new_current = [current[i] + ex[i] for i in range(len(subset_e))] valid = True for i in range(len(subset_e)): if new_current[i] >= subset_e[i]: valid = False break if not valid: continue new_state = 0 for i in range(len(subset_e)): new_state += new_current[i] * multipliers[i] if new_state < state_size: new_dp[new_state] = (new_dp[new_state] + dp[state]) % MOD dp = new_dp total = (sum(dp) - 1) % MOD count = total sign = (-1) ** (k + 1) inclusion_exclusion = (inclusion_exclusion + sign * count) % MOD total_subsets = (pow(2, N, MOD) - 1) % MOD ans = (total_subsets - inclusion_exclusion) % MOD print(ans) if __name__ == '__main__': main()