結果
| 問題 |
No.2313 Product of Subsequence (hard)
|
| ユーザー |
lam6er
|
| 提出日時 | 2025-04-16 16:29:51 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 3,285 bytes |
| コンパイル時間 | 429 ms |
| コンパイル使用メモリ | 81,924 KB |
| 実行使用メモリ | 280,492 KB |
| 最終ジャッジ日時 | 2025-04-16 16:31:31 |
| 合計ジャッジ時間 | 11,916 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 10 TLE * 1 -- * 16 |
ソースコード
# Solution for yukicoder No.2313 "Product of Subsequence (hard)":
# count the non-empty subsequences of A whose product is a multiple of K,
# modulo 998244353.
import sys
from collections import Counter
from math import gcd

MOD = 998244353


def _factorize(k):
    """Return the prime factorization of ``k`` as a ``{prime: exponent}`` dict.

    Uses plain trial division. Any ``k`` below 2 yields an empty dict.
    """
    factors = {}
    p = 2
    while p * p <= k:
        while k % p == 0:
            factors[p] = factors.get(p, 0) + 1
            k //= p
        p += 1
    if k > 1:
        # Whatever is left after trial division is a single prime factor.
        factors[k] = 1
    return factors


def _count_multiple_products(k, values):
    """Count non-empty subsequences of ``values`` whose product is divisible by ``k``.

    Only ``gcd(a, k)`` matters for each element ``a``, so elements are grouped
    by that gcd and the DP runs over "capped exponent vectors": for every
    prime of ``k``, the exponent accumulated so far, clamped at the exponent
    that ``k`` requires. The number of such vectors equals the number of
    divisors of ``k``.

    Args:
        k: The modulus-free divisor target. If it has no prime factor
            (``k < 2``) every non-empty subsequence is counted.
        values: The sequence elements. A zero element is treated as divisible
            by ``k`` (``gcd(0, k) == k``).

    Returns:
        The count modulo ``MOD``.
    """
    factors = _factorize(k)
    if not factors:
        return (pow(2, len(values), MOD) - 1) % MOD

    primes = sorted(factors)
    required = [factors[p] for p in primes]

    # Mixed-radix encoding of a capped exponent vector into one integer:
    # digit i ranges over 0..required[i] and has weight strides[i].
    strides = []
    size = 1
    for need in required:
        strides.append(size)
        size *= need + 1

    # Decode table: vectors[code] is the exponent vector encoded by `code`.
    vectors = []
    for code in range(size):
        rest = code
        digits = []
        for need in required:
            rest, digit = divmod(rest, need + 1)
            digits.append(digit)
        vectors.append(digits)

    shift_rows = {}

    def shift_row(step_code):
        """Return the table state -> state after adding the step vector (capped)."""
        row = shift_rows.get(step_code)
        if row is None:
            step = vectors[step_code]
            row = [
                sum(
                    min(have + add, need) * stride
                    for have, add, need, stride
                    in zip(vec, step, required, strides)
                )
                for vec in vectors
            ]
            shift_rows[step_code] = row
        return row

    groups = Counter(gcd(a, k) for a in values)
    # Elements coprime to k never change the state; each may be freely
    # included or not, which contributes a factor of 2 per element at the end.
    free = groups.pop(1, 0)

    dp = [0] * size
    dp[0] = 1  # the empty selection
    for divisor, count in groups.items():
        # Exponents of this group's gcd; it divides k, so they never exceed
        # the required ones and the division loop always terminates.
        base = []
        rest = divisor
        for p in primes:
            e = 0
            while rest % p == 0:
                rest //= p
                e += 1
            base.append(e)

        # Taking `saturation` or more copies caps every prime this divisor
        # touches, so all those choices lead to the same step vector.
        saturation = max(
            -(-need // e) for e, need in zip(base, required) if e
        )

        options = []  # (encoded step vector, number of ways) pairs
        choose = 1  # exact C(count, j)
        tail = pow(2, count, MOD)  # becomes sum of C(count, j) for j >= saturation
        for j in range(min(saturation, count + 1)):
            step_code = sum(
                min(j * e, need) * stride
                for e, need, stride in zip(base, required, strides)
            )
            options.append((step_code, choose % MOD))
            tail -= choose
            choose = choose * (count - j) // (j + 1)
        if count >= saturation:
            step_code = sum(
                min(saturation * e, need) * stride
                for e, need, stride in zip(base, required, strides)
            )
            options.append((step_code, tail % MOD))

        updated = [0] * size
        for step_code, weight in options:
            row = shift_row(step_code)
            for state, ways in enumerate(dp):
                if ways:
                    target = row[state]
                    updated[target] = (updated[target] + ways * weight) % MOD
        dp = updated

    # The last code has every digit at its maximum, i.e. "divisible by k".
    # Reaching it needs at least one element, so the empty set is excluded.
    return dp[size - 1] * pow(2, free, MOD) % MOD


def main():
    """Read ``N K`` and the sequence from standard input and print the count."""
    readline = sys.stdin.readline
    n, k = map(int, readline().split())
    values = list(map(int, readline().split()))[:n]
    print(_count_multiple_products(k, values))


if __name__ == '__main__':
    main()
lam6er