結果
| 問題 |
No.2313 Product of Subsequence (hard)
|
| ユーザー |
lam6er
|
| 提出日時 | 2025-04-15 23:46:07 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 3,389 bytes |
| コンパイル時間 | 158 ms |
| コンパイル使用メモリ | 81,524 KB |
| 実行使用メモリ | 186,936 KB |
| 最終ジャッジ日時 | 2025-04-15 23:48:51 |
| 合計ジャッジ時間 | 12,266 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 10 TLE * 1 -- * 16 |
ソースコード
MOD = 998244353
def factorize(k):
factors = {}
i = 2
while i * i <= k:
while k % i == 0:
factors[i] = factors.get(i, 0) + 1
k //= i
i += 1
if k > 1:
factors[k] = 1
return factors
def main():
import sys
input = sys.stdin.read().split()
ptr = 0
N, K = int(input[ptr]), int(input[ptr+1])
ptr += 2
A = list(map(int, input[ptr:ptr+N]))
ptr += N
if K == 1:
print((pow(2, N, MOD) - 1) % MOD)
return
factors = factorize(K)
primes = list(factors.keys())
m = len(primes)
e = [factors[p] for p in primes]
element_exponents = []
for num in A:
ex = []
for p in primes:
cnt = 0
tmp = num
while tmp % p == 0:
cnt += 1
tmp //= p
ex.append(cnt)
element_exponents.append(ex)
inclusion_exclusion = 0
for mask in range(1, 1 << m):
subset_indices = []
subset_e = []
for i in range(m):
if (mask >> i) & 1:
subset_indices.append(i)
subset_e.append(e[i])
k = len(subset_indices)
allowed = []
for idx in range(N):
valid = True
for i in subset_indices:
if element_exponents[idx][i] >= e[i]:
valid = False
break
if valid:
allowed.append(idx)
if not allowed:
count = 0
else:
state_size = 1
for ee in subset_e:
state_size *= ee
if state_size > 10**6:
continue
multipliers = [1] * len(subset_e)
for i in range(1, len(subset_e)):
multipliers[i] = multipliers[i-1] * subset_e[i-1]
dp = [0] * state_size
dp[0] = 1
for idx in allowed:
ex = [element_exponents[idx][i] for i in subset_indices]
new_dp = dp.copy()
for state in range(state_size):
if dp[state] == 0:
continue
current = []
rem = state
for i in range(len(subset_e)):
current.append(rem % subset_e[i])
rem = rem // subset_e[i]
new_current = [current[i] + ex[i] for i in range(len(subset_e))]
valid = True
for i in range(len(subset_e)):
if new_current[i] >= subset_e[i]:
valid = False
break
if not valid:
continue
new_state = 0
for i in range(len(subset_e)):
new_state += new_current[i] * multipliers[i]
if new_state < state_size:
new_dp[new_state] = (new_dp[new_state] + dp[state]) % MOD
dp = new_dp
total = (sum(dp) - 1) % MOD
count = total
sign = (-1) ** (k + 1)
inclusion_exclusion = (inclusion_exclusion + sign * count) % MOD
total_subsets = (pow(2, N, MOD) - 1) % MOD
ans = (total_subsets - inclusion_exclusion) % MOD
print(ans)
if __name__ == '__main__':
main()
lam6er