結果
問題 |
No.3119 A Little Cheat
|
ユーザー |
|
提出日時 | 2025-05-23 07:23:46 |
言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
結果 |
WA
|
実行時間 | - |
コード長 | 4,649 bytes |
コンパイル時間 | 728 ms |
コンパイル使用メモリ | 12,288 KB |
実行使用メモリ | 32,344 KB |
最終ジャッジ日時 | 2025-05-23 07:24:21 |
合計ジャッジ時間 | 32,988 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | WA * 3 |
other | AC * 3 WA * 46 |
ソースコード
import sys def solve(): N, M = map(int, sys.stdin.readline().split()) A = list(map(int, sys.stdin.readline().split())) MOD = 998244353 # Calculate SUM_0 = M^(N-1) * sum(max(0, M-A_i)) sum_max_M_minus_Ai = 0 for val_a in A: sum_max_M_minus_Ai = (sum_max_M_minus_Ai + max(0, M - val_a)) % MOD if N == 0: # Based on constraints 2 <= N, this path is not taken. term_MN_minus_1 = 0 elif N == 1: term_MN_minus_1 = 1 else: term_MN_minus_1 = pow(M, N - 1, MOD) SUM_0 = (term_MN_minus_1 * sum_max_M_minus_Ai) % MOD # DP part to calculate S_N = N_all_le_0 # State for dp_prev (representing dp[k-1]) # c_prev1, c_prev2 are cut points. Intervals are [1,c1], (c1,c2], (c2,M]. # v_prev1, v_prev2, v_prev3 are values in these intervals. # Initial state (dp[1]): For B_1, all counts are 1. # This means segment [1,M] has value 1. Segments (M,M] and (M,M] are empty or have value 0. c_prev1, c_prev2 = M, M v_prev1, v_prev2, v_prev3 = 1, 0, 0 # v_prev2,v_prev3 are effectively 0 for empty segments # Loop for k = 2 to N (dp state for B_k based on B_{k-1}) # k_loop_idx refers to the index of B being determined (e.g. B_2, B_3, ..., B_N) for k_loop_idx in range(2, N + 1): # Calculate S_km1 (Sum for B_{k-1}) from (c_prev1, c_prev2, v_prev1, v_prev2, v_prev3) l1 = max(0, c_prev1) l2 = max(0, c_prev2 - c_prev1) l3 = max(0, M - c_prev2) s_km1 = (v_prev1 * l1) % MOD s_km1 = (s_km1 + v_prev2 * l2) % MOD s_km1 = (s_km1 + v_prev3 * l3) % MOD s_km1 = (s_km1 + MOD) % MOD # P_km1(X) function for B_{k-1} state def p_km1(val_x_func): if val_x_func <= 0: return 0 res_func = 0 len_seg1_func = max(0, min(val_x_func, c_prev1)) res_func = (res_func + v_prev1 * len_seg1_func) % MOD if val_x_func > c_prev1: len_seg2_func = max(0, min(val_x_func, c_prev2) - c_prev1) res_func = (res_func + v_prev2 * len_seg2_func) % MOD if val_x_func > c_prev2: len_seg3_func = max(0, min(val_x_func, M) - c_prev2) res_func = (res_func + v_prev3 * len_seg3_func) % MOD return (res_func + MOD) % MOD # A_values for current DP transition refer to A_{k-1} and A_k from problem statement. # In 0-indexed A array: A[k_loop_idx - 2] and A[k_loop_idx - 1] prev_A_val_dp = A[k_loop_idx - 2] # This is A_{k-1} in problem notation cur_A_val_dp = A[k_loop_idx - 1] # This is A_k in problem notation # Value for B_k if B_{k-1} <= A_{k-1} < B_k <= A_k (causes \Delta = 1) val_when_Bk_in_Aj_Ak_interval = (s_km1 - p_km1(prev_A_val_dp) + MOD) % MOD # Value for B_k if B_k <= A_k < B_{k-1} <= A_{k-1} (causes \Delta = 1) val_when_Bk_le_Ak_and_Bkprev_in_Ak_Akm1_interval = \ (s_km1 - (p_km1(prev_A_val_dp) - p_km1(cur_A_val_dp) + MOD) % MOD + MOD) % MOD next_c1_dp, next_c2_dp = 0,0 next_v1_dp, next_v2_dp, next_v3_dp = 0,0,0 if prev_A_val_dp <= cur_A_val_dp: next_c1_dp = prev_A_val_dp next_c2_dp = cur_A_val_dp next_v1_dp = s_km1 # For B_k in [1, prev_A_val_dp] next_v2_dp = val_when_Bk_in_Aj_Ak_interval # For B_k in (prev_A_val_dp, cur_A_val_dp] next_v3_dp = s_km1 # For B_k in (cur_A_val_dp, M] else: # cur_A_val_dp < prev_A_val_dp next_c1_dp = cur_A_val_dp next_c2_dp = prev_A_val_dp next_v1_dp = val_when_Bk_le_Ak_and_Bkprev_in_Ak_Akm1_interval # For B_k in [1, cur_A_val_dp] next_v2_dp = s_km1 # For B_k in (cur_A_val_dp, prev_A_val_dp] next_v3_dp = s_km1 # For B_k in (prev_A_val_dp, M] c_prev1, c_prev2 = next_c1_dp, next_c2_dp v_prev1, v_prev2, v_prev3 = next_v1_dp, next_v2_dp, next_v3_dp # Calculate S_N (N_all_le_0) from the final state (c_prev1,..v_prev3) which represents dp[N] if N == 1: s_N = M else: l1_final = max(0, c_prev1) l2_final = max(0, c_prev2 - c_prev1) l3_final = max(0, M - c_prev2) s_N = (v_prev1 * l1_final) % MOD s_N = (s_N + v_prev2 * l2_final) % MOD s_N = (s_N + v_prev3 * l3_final) % MOD s_N = (s_N + MOD) % MOD sum_Y = (pow(M, N, MOD) - s_N + MOD) % MOD total_sum = (SUM_0 + sum_Y) % MOD sys.stdout.write(str(total_sum) + "\n") solve()