結果

問題 No.3119 A Little Cheat
ユーザー keigo kuwata
提出日時 2025-05-23 07:23:46
言語 Python3
(3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果
WA  
実行時間 -
コード長 4,649 bytes
コンパイル時間 728 ms
コンパイル使用メモリ 12,288 KB
実行使用メモリ 32,344 KB
最終ジャッジ日時 2025-05-23 07:24:21
合計ジャッジ時間 32,988 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample WA * 3
other AC * 3 WA * 46
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def solve():
    N, M = map(int, sys.stdin.readline().split())
    A = list(map(int, sys.stdin.readline().split()))

    MOD = 998244353

    # Calculate SUM_0 = M^(N-1) * sum(max(0, M-A_i))
    sum_max_M_minus_Ai = 0
    for val_a in A:
        sum_max_M_minus_Ai = (sum_max_M_minus_Ai + max(0, M - val_a)) % MOD
    
    if N == 0: # Based on constraints 2 <= N, this path is not taken.
        term_MN_minus_1 = 0
    elif N == 1: 
        term_MN_minus_1 = 1
    else:
        term_MN_minus_1 = pow(M, N - 1, MOD)
    
    SUM_0 = (term_MN_minus_1 * sum_max_M_minus_Ai) % MOD

    # DP part to calculate S_N = N_all_le_0
    # State for dp_prev (representing dp[k-1])
    # c_prev1, c_prev2 are cut points. Intervals are [1,c1], (c1,c2], (c2,M].
    # v_prev1, v_prev2, v_prev3 are values in these intervals.
    
    # Initial state (dp[1]): For B_1, all counts are 1.
    # This means segment [1,M] has value 1. Segments (M,M] and (M,M] are empty or have value 0.
    c_prev1, c_prev2 = M, M 
    v_prev1, v_prev2, v_prev3 = 1, 0, 0 # v_prev2,v_prev3 are effectively 0 for empty segments

    # Loop for k = 2 to N (dp state for B_k based on B_{k-1})
    # k_loop_idx refers to the index of B being determined (e.g. B_2, B_3, ..., B_N)
    for k_loop_idx in range(2, N + 1): 
        # Calculate S_km1 (Sum for B_{k-1}) from (c_prev1, c_prev2, v_prev1, v_prev2, v_prev3)
        
        l1 = max(0, c_prev1) 
        l2 = max(0, c_prev2 - c_prev1)
        l3 = max(0, M - c_prev2)
        
        s_km1 = (v_prev1 * l1) % MOD
        s_km1 = (s_km1 + v_prev2 * l2) % MOD
        s_km1 = (s_km1 + v_prev3 * l3) % MOD
        s_km1 = (s_km1 + MOD) % MOD 

        # P_km1(X) function for B_{k-1} state
        def p_km1(val_x_func):
            if val_x_func <= 0: return 0
            res_func = 0
            
            len_seg1_func = max(0, min(val_x_func, c_prev1))
            res_func = (res_func + v_prev1 * len_seg1_func) % MOD
            
            if val_x_func > c_prev1:
                len_seg2_func = max(0, min(val_x_func, c_prev2) - c_prev1)
                res_func = (res_func + v_prev2 * len_seg2_func) % MOD
            
            if val_x_func > c_prev2:
                len_seg3_func = max(0, min(val_x_func, M) - c_prev2)
                res_func = (res_func + v_prev3 * len_seg3_func) % MOD
            
            return (res_func + MOD) % MOD

        # A_values for current DP transition refer to A_{k-1} and A_k from problem statement.
        # In 0-indexed A array: A[k_loop_idx - 2] and A[k_loop_idx - 1]
        prev_A_val_dp = A[k_loop_idx - 2] # This is A_{k-1} in problem notation
        cur_A_val_dp = A[k_loop_idx - 1]  # This is A_k in problem notation

        # Value for B_k if B_{k-1} <= A_{k-1} < B_k <= A_k (causes \Delta = 1)
        val_when_Bk_in_Aj_Ak_interval = (s_km1 - p_km1(prev_A_val_dp) + MOD) % MOD
        
        # Value for B_k if B_k <= A_k < B_{k-1} <= A_{k-1} (causes \Delta = 1)
        val_when_Bk_le_Ak_and_Bkprev_in_Ak_Akm1_interval = \
            (s_km1 - (p_km1(prev_A_val_dp) - p_km1(cur_A_val_dp) + MOD) % MOD + MOD) % MOD

        next_c1_dp, next_c2_dp = 0,0
        next_v1_dp, next_v2_dp, next_v3_dp = 0,0,0

        if prev_A_val_dp <= cur_A_val_dp:
            next_c1_dp = prev_A_val_dp
            next_c2_dp = cur_A_val_dp
            next_v1_dp = s_km1 # For B_k in [1, prev_A_val_dp]
            next_v2_dp = val_when_Bk_in_Aj_Ak_interval # For B_k in (prev_A_val_dp, cur_A_val_dp]
            next_v3_dp = s_km1 # For B_k in (cur_A_val_dp, M]
        else: # cur_A_val_dp < prev_A_val_dp
            next_c1_dp = cur_A_val_dp
            next_c2_dp = prev_A_val_dp
            next_v1_dp = val_when_Bk_le_Ak_and_Bkprev_in_Ak_Akm1_interval # For B_k in [1, cur_A_val_dp]
            next_v2_dp = s_km1 # For B_k in (cur_A_val_dp, prev_A_val_dp]
            next_v3_dp = s_km1 # For B_k in (prev_A_val_dp, M]
        
        c_prev1, c_prev2 = next_c1_dp, next_c2_dp
        v_prev1, v_prev2, v_prev3 = next_v1_dp, next_v2_dp, next_v3_dp
    
    # Calculate S_N (N_all_le_0) from the final state (c_prev1,..v_prev3) which represents dp[N]
    if N == 1: 
         s_N = M 
    else:
        l1_final = max(0, c_prev1)
        l2_final = max(0, c_prev2 - c_prev1)
        l3_final = max(0, M - c_prev2)
        
        s_N = (v_prev1 * l1_final) % MOD
        s_N = (s_N + v_prev2 * l2_final) % MOD
        s_N = (s_N + v_prev3 * l3_final) % MOD
        s_N = (s_N + MOD) % MOD
            
    sum_Y = (pow(M, N, MOD) - s_N + MOD) % MOD
    
    total_sum = (SUM_0 + sum_Y) % MOD
    
    sys.stdout.write(str(total_sum) + "\n")

solve()
0