結果
| 問題 |
No.2313 Product of Subsequence (hard)
|
| ユーザー |
gew1fw
|
| 提出日時 | 2025-06-12 18:37:25 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 3,602 bytes |
| コンパイル時間 | 180 ms |
| コンパイル使用メモリ | 81,832 KB |
| 実行使用メモリ | 154,136 KB |
| 最終ジャッジ日時 | 2025-06-12 18:37:43 |
| 合計ジャッジ時間 | 11,365 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 10 TLE * 1 -- * 16 |
ソースコード
import sys
from collections import defaultdict
MOD = 998244353
def factorize(k):
factors = {}
i = 2
while i * i <= k:
if k % i == 0:
cnt = 0
while k % i == 0:
cnt += 1
k //= i
factors[i] = cnt
i += 1
if k > 1:
factors[k] = 1
return factors
def main():
input = sys.stdin.read().split()
ptr = 0
N = int(input[ptr])
ptr += 1
K = int(input[ptr])
ptr += 1
A = list(map(int, input[ptr:ptr+N]))
ptr += N
if K == 1:
print((pow(2, N, MOD) - 1) % MOD)
return
factors = factorize(K)
primes = list(factors.keys())
required = [factors[p] for p in primes]
m = len(primes)
# Precompute exponents for each element
relevant = []
neutral = []
for a in A:
ex = []
valid = False
for p in primes:
cnt = 0
x = a
while x % p == 0:
cnt += 1
x //= p
ex.append(cnt)
if cnt > 0:
valid = True
if valid:
relevant.append(ex)
else:
neutral.append(a)
M = len(neutral)
R = len(relevant)
if not relevant:
print(0)
return
# Precompute powers of 2
max_pow = max(len(relevant), len(neutral)) + 1
pow2 = [1] * (max_pow)
for i in range(1, max_pow):
pow2[i] = (pow2[i-1] * 2) % MOD
# Inclusion-exclusion over all subsets of primes
total_valid = 0
from itertools import combinations
for mask in range(1 << m):
S = [i for i in range(m) if (mask >> i) & 1]
sign = (-1) ** len(S)
if not S:
term = pow2[R] # all subsets of relevant elements
total_valid = (total_valid + sign * term) % MOD
continue
# Compute allowed sums for each prime in S
allowed = [required[i] - 1 for i in S]
primes_in_S = [primes[i] for i in S]
# Compute C: number of relevant elements with zero exponents for all primes in S
C = 0
T_prime = []
for ex in relevant:
ex_S = [ex[i] for i in S]
if all(e == 0 for e in ex_S):
C += 1
else:
if all(e <= allowed[i] for i, e in enumerate(ex_S)) and any(e > 0 for e in ex_S):
T_prime.append(ex_S)
# Compute DP for T_prime
m_prime = len(S)
dp = defaultdict(int)
initial = tuple([0]*m_prime)
dp[initial] = 1
for ex_S in T_prime:
new_dp = defaultdict(int)
for state, cnt in dp.items():
new_state = list(state)
valid = True
for i in range(m_prime):
new_state[i] += ex_S[i]
if new_state[i] > allowed[i]:
valid = False
break
if valid:
new_state_tuple = tuple(new_state)
new_dp[new_state_tuple] = (new_dp[new_state_tuple] + cnt) % MOD
new_dp[tuple(state)] = (new_dp.get(tuple(state), 0) + cnt) % MOD
dp = new_dp
valid_subsets = sum(dp.values()) % MOD
term = (valid_subsets * pow2[C]) % MOD
total_valid = (total_valid + sign * term) % MOD
# Ensure total_valid is non-negative
total_valid = (total_valid % MOD + MOD) % MOD
# Multiply by 2^M (neutral elements)
result = (total_valid * pow2[M]) % MOD
print(result)
if __name__ == "__main__":
main()
gew1fw