結果
| 問題 | No.2313 Product of Subsequence (hard) |
| ユーザー | lam6er |
| 提出日時 | 2025-03-31 17:50:53 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 3,367 bytes |
| コンパイル時間 | 296 ms |
| コンパイル使用メモリ | 82,160 KB |
| 実行使用メモリ | 268,356 KB |
| 最終ジャッジ日時 | 2025-03-31 17:52:07 |
| 合計ジャッジ時間 | 9,672 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 10 TLE * 1 -- * 16 |
ソースコード
# Modulus applied to every count this program computes and prints.
MOD = 998244353
def factorize(n):
    """Return the prime factorization of ``n`` as ``{prime: exponent}``.

    Uses trial division by candidates up to ``sqrt(n)``. Primes are inserted
    in increasing order, so iterating the result yields them sorted.

    Args:
        n: Integer to factorize.

    Returns:
        A dict mapping each prime divisor of ``n`` to its multiplicity.
        For ``n < 2`` the dict is empty.
    """
    factors = {}
    candidate = 2
    while candidate * candidate <= n:
        if n % candidate == 0:
            exponent = 0
            while n % candidate == 0:
                n //= candidate
                exponent += 1
            factors[candidate] = exponent
        # Once 2 has been handled, even candidates can never divide n.
        candidate += 1 if candidate == 2 else 2
    # Whatever is left above 1 has no divisor <= its square root: it is prime.
    if n > 1:
        factors[n] = 1
    return factors
def _capped_exponents(value, primes):
    """Return how often each prime divides ``value``, capped at its target.

    ``primes`` is a list of ``(p, e)`` pairs. The i-th entry of the result is
    ``min(v, e)`` where ``v`` is the multiplicity of ``p`` in ``value``.
    """
    exponents = []
    for p, e in primes:
        count = 0
        # Stopping at e bounds the loop (so it also terminates for value == 0)
        # and loses nothing: callers only compare the count against e.
        while count < e and value % p == 0:
            value //= p
            count += 1
        exponents.append(count)
    return exponents


def _count_below_limits(vectors, limits):
    """Count selections of ``vectors`` whose coordinate sums stay below ``limits``.

    Every subset of ``vectors`` (the empty one included) is counted once if,
    in each coordinate ``i``, the sum over the subset is ``< limits[i]``.
    The count is returned modulo ``MOD``.
    """
    # Maps a tuple of coordinate sums to the number of subsets reaching it.
    ways = {(0,) * len(limits): 1}
    for vec in vectors:
        # Leaving `vec` out keeps every existing state unchanged.
        updated = dict(ways)
        for sums, count in ways.items():
            grown = tuple(s + v for s, v in zip(sums, vec))
            if all(g < lim for g, lim in zip(grown, limits)):
                updated[grown] = (updated.get(grown, 0) + count) % MOD
        ways = updated
    return sum(ways.values()) % MOD


def main():
    """Read ``N K`` and the sequence ``A`` from stdin and print the answer.

    The printed value is the number of non-empty selections of elements of
    ``A`` whose product is divisible by ``K``, modulo ``MOD``.

    For ``K > 1`` the count is obtained by inclusion-exclusion over subsets
    ``T`` of the prime factors of ``K``: for each ``T`` it counts the
    selections in which every prime of ``T`` falls short of the exponent it
    has in ``K``, and adds that count with sign ``(-1) ** len(T)``. The empty
    selection cancels out because it falls short for every prime.

    NOTE(review): this performs one DP pass over the elements for each of the
    ``2 ** m`` prime subsets; the judge verdict recorded for this submission
    was TLE on one case, so a faster formulation may still be required.
    """
    import sys

    N, K = map(int, sys.stdin.readline().split())
    A = list(map(int, sys.stdin.readline().split()))
    if K == 1:
        # Every non-empty selection has a product divisible by 1.
        print((pow(2, N, MOD) - 1) % MOD)
        return

    primes = list(factorize(K).items())
    m = len(primes)

    # `core` holds the exponent vectors of elements sharing a prime with K;
    # all other elements can be freely included and only scale the answer.
    core = []
    irrelevant = 0
    for a in A:
        exps = _capped_exponents(a, primes)
        if any(exps):
            core.append(exps)
        else:
            irrelevant += 1

    answer = 0
    for mask in range(1 << m):
        chosen = [i for i in range(m) if (mask >> i) & 1]
        limits = [primes[i][1] for i in chosen]

        vectors = []
        free = 0
        for exps in core:
            projected = [exps[i] for i in chosen]
            # An element that alone meets a chosen prime's exponent can never
            # be part of a selection that falls short for that prime.
            if any(c >= lim for c, lim in zip(projected, limits)):
                continue
            if any(projected):
                vectors.append(projected)
            else:
                # Contributes nothing to the chosen primes: in or out freely.
                free += 1

        total = _count_below_limits(vectors, limits) * pow(2, free, MOD) % MOD
        if len(chosen) % 2:
            answer = (answer - total) % MOD
        else:
            answer = (answer + total) % MOD

    print(answer * pow(2, irrelevant, MOD) % MOD)
# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
lam6er