結果
| 問題 | No.1239 Multiplication -2 |
|---|---|
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-03-26 15:57:02 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 195 ms / 2,000 ms |
| コード長 | 3,276 bytes |
| コンパイル時間 | 388 ms |
| コンパイル使用メモリ | 81,920 KB |
| 実行使用メモリ | 138,272 KB |
| 最終ジャッジ日時 | 2025-03-26 15:57:31 |
| 合計ジャッジ時間 | 5,762 ms |
| ジャッジサーバーID (参考情報) | judge1 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 34 |
ソースコード
import sys
# Prime modulus for every residue printed by this script: 119 * 2**23 + 1 == 998244353.
MOD = 119 * 2**23 + 1
def _nonzero_runs(a):
    """Yield ``(lo, hi)``: inclusive 0-based bounds of each maximal run of non-zero entries of *a*."""
    lo = None
    for k, x in enumerate(a):
        if x != 0:
            if lo is None:
                lo = k
        elif lo is not None:
            yield lo, k - 1
            lo = None
    if lo is not None:
        yield lo, len(a) - 1


def solve(a, mod=None):
    """Return the weighted number of intervals of *a* whose product is -2.

    An interval ``a[i..j]`` (0-based, inclusive) of a length-``n`` sequence
    has weight ``2 ** -(j - i + 2)``, doubled when ``i == 0`` and doubled
    again when ``j == n - 1``.  The result is the sum of those weights over
    all intervals whose product equals -2, as a residue modulo *mod*.

    Only entries equal to +-2 contribute magnitude; every other non-zero
    entry is treated as +-1, so *a* is assumed to hold values in
    {-2, -1, 0, 1, 2}.  A qualifying interval therefore lies inside a run of
    non-zero entries, contains exactly one +-2, and has an odd number of
    negative entries.

    Args:
        a: the sequence.
        mod: odd prime modulus (2 must be invertible via Fermat); defaults
            to the module-level ``MOD``.

    Returns:
        The sum reduced into ``range(mod)``; 0 for an empty sequence.
    """
    if mod is None:
        mod = MOD
    n = len(a)
    if n == 0:
        return 0

    # pow2[k] = 2**k and inv_pow2[k] = 2**-k (mod p).  Indices used below
    # never exceed n, so the tables are sized to the input, not a fixed cap.
    inv2 = pow(2, mod - 2, mod)
    pow2 = [1] * (n + 1)
    inv_pow2 = [1] * (n + 1)
    for k in range(1, n + 1):
        pow2[k] = pow2[k - 1] * 2 % mod
        inv_pow2[k] = inv_pow2[k - 1] * inv2 % mod

    # parity[k] = parity of the number of negative entries in a[:k]; so
    # a[i..j] has an odd number of negatives iff parity[i] != parity[j + 1].
    parity = [0] * (n + 1)
    for k, x in enumerate(a):
        parity[k + 1] = parity[k] ^ (x < 0)

    # The weight factorises as start(i) * end(j) with
    #   start(i) = 2**i          (doubled when i == 0)
    #   end(j)   = 2**-(j + 2)   (doubled when j == n - 1)
    # start_sum[b][k] = sum of start(i) over i < k with parity[i] == b.
    start_sum = ([0] * (n + 1), [0] * (n + 1))
    for i in range(n):
        w = 2 * pow2[i] % mod if i == 0 else pow2[i]
        b = parity[i]
        start_sum[b][i + 1] = (start_sum[b][i] + w) % mod
        start_sum[1 - b][i + 1] = start_sum[1 - b][i]

    total = 0
    for lo, hi in _nonzero_runs(a):
        twos = [k for k in range(lo, hi + 1) if abs(a[k]) == 2]
        # Sentinels just outside the run delimit the first and last windows;
        # a run without any +-2 yields no triples and contributes nothing.
        bounds = [lo - 1] + twos + [hi + 1]
        for prev, p, nxt in zip(bounds, bounds[1:], bounds[2:]):
            # Intervals containing only the +-2 at p: start in (prev, p],
            # end in [p, nxt).  Sum matching-parity starts per end in O(1).
            for j in range(p, nxt):
                want = parity[j + 1] ^ 1
                starts = start_sum[want][p + 1] - start_sum[want][prev + 1]
                end_w = inv_pow2[j + 1] if j == n - 1 else inv_pow2[j + 2]
                total = (total + starts * end_w) % mod
    return total


def main():
    """Read ``N`` followed by ``N`` integers from stdin and print ``solve`` of them."""
    data = sys.stdin.read().split()
    n = int(data[0])
    a = list(map(int, data[1:n + 1]))
    print(solve(a))
if __name__ == "__main__":
    # Solve only when run as a script; importing this module has no side effects.
    main()
lam6er