結果

問題 No.3119 A Little Cheat
ユーザー shibh308
提出日時 2025-04-20 18:57:59
言語 Python3
(3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果
TLE  
実行時間 -
コード長 2,562 bytes
コンパイル時間 140 ms
コンパイル使用メモリ 12,672 KB
実行使用メモリ 32,764 KB
最終ジャッジ日時 2025-04-20 18:59:25
合計ジャッジ時間 83,996 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 21 TLE * 28
権限があれば一括ダウンロードができます

ソースコード

diff #

def main():
    MOD = 998244353

    N, M = map(int, input().split())
    A = list(map(int, input().split()))

    # 1) M^k の前計算
    powM = [1] * (N+1)
    M_mod = M % MOD
    for i in range(1, N+1):
        powM[i] = powM[i-1] * M_mod % MOD

    # 2) ベーススコア
    base = sum((M - a) for a in A) * powM[N-1] % MOD
    if N == 1:
        print(base)
        return

    # 3) DP 初期化(長さ1 の場合、B1 が小・中・大区間に入る数)
    t0, T0 = min(A[0], A[1]), max(A[0], A[1])
    dp_sm = t0
    dp_md = T0 - t0
    dp_lg = M - T0

    X = 0
    def overlap(a1, b1, a2, b2):
        return max(0, min(b1, b2) - max(a1, a2) + 1)

    # 4) 各 i で「はじめて改善させる」組合せ数を足しつつ DP を更新
    for i in range(N-1):
        ti, Ti = min(A[i], A[i+1]), max(A[i], A[i+1])
        sm_i = ti
        md_i = Ti - ti
        lg_i = M - Ti
        tail = powM[N-i-2] if N-i-2 >= 0 else 1

        # 貢献計算
        if A[i] < A[i+1]:
            X = (X + (dp_sm + dp_lg) % MOD * md_i % MOD * tail) % MOD
        else:
            X = (X + dp_md % MOD * (sm_i + lg_i) % MOD * tail) % MOD

        if i == N-2:
            break

        # 次のペアの閾値
        t_n, T_n = min(A[i+1], A[i+2]), max(A[i+1], A[i+2])
        sm_n, md_n, lg_n = t_n, T_n - t_n, M - T_n

        # 各状態から B_{i+1} を選ぶ「回避領域」を構築
        allowed = {}
        if A[i] < A[i+1]:
            allowed['sm'] = [(1, ti), (Ti+1, M)]
            allowed['md'] = [(1, M)]
            allowed['lg'] = allowed['sm']
        else:
            allowed['sm'] = [(1, M)]
            allowed['md'] = [(ti+1, Ti)]
            allowed['lg'] = allowed['sm']

        next_intervals = {
            'sm': [(1, t_n)],
            'md': [(t_n+1, T_n)],
            'lg': [(T_n+1, M)]
        }

        new_sm = new_md = new_lg = 0
        for reg, dpv in (('sm', dp_sm), ('md', dp_md), ('lg', dp_lg)):
            rngs = allowed[reg]
            for nxt_reg, ivals in next_intervals.items():
                cnt = 0
                for a1, b1 in rngs:
                    for a2, b2 in ivals:
                        cnt += overlap(a1, b1, a2, b2)
                if nxt_reg == 'sm':
                    new_sm += dpv * cnt
                elif nxt_reg == 'md':
                    new_md += dpv * cnt
                else:
                    new_lg += dpv * cnt

        dp_sm, dp_md, dp_lg = new_sm % MOD, new_md % MOD, new_lg % MOD

    print((base + X) % MOD)

if __name__ == "__main__":
    main()
0