結果
| 問題 |
No.3119 A Little Cheat
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2025-05-23 07:23:46 |
| 言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 4,649 bytes |
| コンパイル時間 | 728 ms |
| コンパイル使用メモリ | 12,288 KB |
| 実行使用メモリ | 32,344 KB |
| 最終ジャッジ日時 | 2025-05-23 07:24:21 |
| 合計ジャッジ時間 | 32,988 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | WA * 3 |
| other | AC * 3 WA * 46 |
ソースコード
import sys
def solve():
N, M = map(int, sys.stdin.readline().split())
A = list(map(int, sys.stdin.readline().split()))
MOD = 998244353
# Calculate SUM_0 = M^(N-1) * sum(max(0, M-A_i))
sum_max_M_minus_Ai = 0
for val_a in A:
sum_max_M_minus_Ai = (sum_max_M_minus_Ai + max(0, M - val_a)) % MOD
if N == 0: # Based on constraints 2 <= N, this path is not taken.
term_MN_minus_1 = 0
elif N == 1:
term_MN_minus_1 = 1
else:
term_MN_minus_1 = pow(M, N - 1, MOD)
SUM_0 = (term_MN_minus_1 * sum_max_M_minus_Ai) % MOD
# DP part to calculate S_N = N_all_le_0
# State for dp_prev (representing dp[k-1])
# c_prev1, c_prev2 are cut points. Intervals are [1,c1], (c1,c2], (c2,M].
# v_prev1, v_prev2, v_prev3 are values in these intervals.
# Initial state (dp[1]): For B_1, all counts are 1.
# This means segment [1,M] has value 1. Segments (M,M] and (M,M] are empty or have value 0.
c_prev1, c_prev2 = M, M
v_prev1, v_prev2, v_prev3 = 1, 0, 0 # v_prev2,v_prev3 are effectively 0 for empty segments
# Loop for k = 2 to N (dp state for B_k based on B_{k-1})
# k_loop_idx refers to the index of B being determined (e.g. B_2, B_3, ..., B_N)
for k_loop_idx in range(2, N + 1):
# Calculate S_km1 (Sum for B_{k-1}) from (c_prev1, c_prev2, v_prev1, v_prev2, v_prev3)
l1 = max(0, c_prev1)
l2 = max(0, c_prev2 - c_prev1)
l3 = max(0, M - c_prev2)
s_km1 = (v_prev1 * l1) % MOD
s_km1 = (s_km1 + v_prev2 * l2) % MOD
s_km1 = (s_km1 + v_prev3 * l3) % MOD
s_km1 = (s_km1 + MOD) % MOD
# P_km1(X) function for B_{k-1} state
def p_km1(val_x_func):
if val_x_func <= 0: return 0
res_func = 0
len_seg1_func = max(0, min(val_x_func, c_prev1))
res_func = (res_func + v_prev1 * len_seg1_func) % MOD
if val_x_func > c_prev1:
len_seg2_func = max(0, min(val_x_func, c_prev2) - c_prev1)
res_func = (res_func + v_prev2 * len_seg2_func) % MOD
if val_x_func > c_prev2:
len_seg3_func = max(0, min(val_x_func, M) - c_prev2)
res_func = (res_func + v_prev3 * len_seg3_func) % MOD
return (res_func + MOD) % MOD
# A_values for current DP transition refer to A_{k-1} and A_k from problem statement.
# In 0-indexed A array: A[k_loop_idx - 2] and A[k_loop_idx - 1]
prev_A_val_dp = A[k_loop_idx - 2] # This is A_{k-1} in problem notation
cur_A_val_dp = A[k_loop_idx - 1] # This is A_k in problem notation
# Value for B_k if B_{k-1} <= A_{k-1} < B_k <= A_k (causes \Delta = 1)
val_when_Bk_in_Aj_Ak_interval = (s_km1 - p_km1(prev_A_val_dp) + MOD) % MOD
# Value for B_k if B_k <= A_k < B_{k-1} <= A_{k-1} (causes \Delta = 1)
val_when_Bk_le_Ak_and_Bkprev_in_Ak_Akm1_interval = \
(s_km1 - (p_km1(prev_A_val_dp) - p_km1(cur_A_val_dp) + MOD) % MOD + MOD) % MOD
next_c1_dp, next_c2_dp = 0,0
next_v1_dp, next_v2_dp, next_v3_dp = 0,0,0
if prev_A_val_dp <= cur_A_val_dp:
next_c1_dp = prev_A_val_dp
next_c2_dp = cur_A_val_dp
next_v1_dp = s_km1 # For B_k in [1, prev_A_val_dp]
next_v2_dp = val_when_Bk_in_Aj_Ak_interval # For B_k in (prev_A_val_dp, cur_A_val_dp]
next_v3_dp = s_km1 # For B_k in (cur_A_val_dp, M]
else: # cur_A_val_dp < prev_A_val_dp
next_c1_dp = cur_A_val_dp
next_c2_dp = prev_A_val_dp
next_v1_dp = val_when_Bk_le_Ak_and_Bkprev_in_Ak_Akm1_interval # For B_k in [1, cur_A_val_dp]
next_v2_dp = s_km1 # For B_k in (cur_A_val_dp, prev_A_val_dp]
next_v3_dp = s_km1 # For B_k in (prev_A_val_dp, M]
c_prev1, c_prev2 = next_c1_dp, next_c2_dp
v_prev1, v_prev2, v_prev3 = next_v1_dp, next_v2_dp, next_v3_dp
# Calculate S_N (N_all_le_0) from the final state (c_prev1,..v_prev3) which represents dp[N]
if N == 1:
s_N = M
else:
l1_final = max(0, c_prev1)
l2_final = max(0, c_prev2 - c_prev1)
l3_final = max(0, M - c_prev2)
s_N = (v_prev1 * l1_final) % MOD
s_N = (s_N + v_prev2 * l2_final) % MOD
s_N = (s_N + v_prev3 * l3_final) % MOD
s_N = (s_N + MOD) % MOD
sum_Y = (pow(M, N, MOD) - s_N + MOD) % MOD
total_sum = (SUM_0 + sum_Y) % MOD
sys.stdout.write(str(total_sum) + "\n")
solve()