結果
問題 |
No.2313 Product of Subsequence (hard)
|
ユーザー |
![]() |
提出日時 | 2025-03-31 17:50:53 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,367 bytes |
コンパイル時間 | 296 ms |
コンパイル使用メモリ | 82,160 KB |
実行使用メモリ | 268,356 KB |
最終ジャッジ日時 | 2025-03-31 17:52:07 |
合計ジャッジ時間 | 9,672 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 10 TLE * 1 -- * 16 |
ソースコード
MOD = 998244353 def factorize(n): factors = {} i = 2 while i * i <= n: if n % i == 0: cnt = 0 while n % i == 0: cnt +=1 n //=i factors[i] = cnt i +=1 if n >1: factors[n] =1 return factors def main(): import sys N, K = map(int, sys.stdin.readline().split()) A = list(map(int, sys.stdin.readline().split())) if K ==1: print((pow(2, N, MOD) -1) % MOD) return k_factors = factorize(K) primes = list(k_factors.items()) m = len(primes) core = [] M =0 for a in A: contributions = [] valid = False for (p, e) in primes: cnt =0 temp = a while temp % p ==0: cnt +=1 temp //=p contributions.append(cnt) if cnt >0: valid = True if valid: core.append(contributions) else: M +=1 sum_ans =0 for mask in range(1 << m): bits = bin(mask).count('1') T = [] Ts = [] e_list = [] for i in range(m): if (mask >> i) &1: T.append(i) Ts.append(i) e_list.append(primes[i][1]) T_set = set(Ts) filtered = [] zero_count =0 for contri in core: include = True for idx in Ts: if contri[idx] >= primes[idx][1]: include =False break if not include: continue is_zero = True new_contri = [] for idx in Ts: c = contri[idx] new_contri.append(c) if c !=0: is_zero = False if is_zero: zero_count +=1 else: filtered.append(new_contri) state = {} initial = tuple([0]*len(Ts)) state[initial] =1 for f in filtered: new_state = {} for s in state: cnt = state[s] new_s = list(s) valid = True for i in range(len(new_s)): new_s[i] += f[i] if new_s[i] >= e_list[i]: valid = False break if valid: new_s_tuple = tuple(new_s) if new_s_tuple in new_state: new_state[new_s_tuple] = (new_state[new_s_tuple] + cnt) % MOD else: new_state[new_s_tuple] = cnt % MOD if s in new_state: new_state[s] = (new_state[s] + cnt) % MOD else: new_state[s] = cnt % MOD state = new_state total =0 for v in state.values(): total = (total + v) % MOD total = total * pow(2, zero_count, MOD) % MOD sign = (-1)**(bits) if sign ==1: sum_ans = (sum_ans + total) % MOD else: sum_ans = (sum_ans - total) % MOD ans = sum_ans % MOD ans = ans * pow(2, M, MOD) % MOD print(ans) if __name__ == "__main__": main()