結果

問題 No.3119 A Little Cheat
ユーザー 👑 tatyam
提出日時 2025-04-20 14:09:09
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 53 ms / 2,000 ms
コード長 3,857 bytes
コンパイル時間 3,649 ms
コンパイル使用メモリ 281,984 KB
実行使用メモリ 7,844 KB
最終ジャッジ日時 2025-04-20 14:09:18
合計ジャッジ時間 8,083 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 49
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#include <atcoder/modint>
using namespace std;
using mint = atcoder::modint998244353;

// swap による増分だけを数えれば良いのか、かしけ〜
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    long long M;
    cin >> N >> M;
    vector<long long> A(N+2);
    for (int i = 1; i <= N; i++) cin >> A[i];
    A[N+1] = M;  // sentinel for A[N+1]

    // sum of base scores without any swap
    mint base_sum = 0;
    if (N >= 1) {
        // sum_{i=1..N} (M - A[i]) * M^{N-1}
        mint powM = mint(M).pow(max(0, N-1));
        mint tot = 0;
        for (int i = 1; i <= N; i++) tot += mint(M - A[i]);
        base_sum = powM * tot;
    }

    // DP to count sequences by max swap-gain = 1 or 2
    mint cnt1 = 0, cnt2 = 0;
    if (N >= 2) {
        // dp[mx][prev_t1][prev_t2]
        static mint dp_cur[3][2][2], dp_nxt[3][2][2];
        // init for i=1 (no gain yet)
        long long t1 = A[1], t2 = A[2];
        for (int mx = 0; mx < 3; mx++) for (int x = 0; x < 2; x++) for (int y = 0; y < 2; y++) dp_cur[mx][x][y] = 0;
        for (int x = 0; x < 2; x++) {
            for (int y = 0; y < 2; y++) {
                long long lb = 1, ub = M;
                if (x) lb = max(lb, t1 + 1);
                else ub = min(ub, t1);
                if (y) lb = max(lb, t2 + 1);
                else ub = min(ub, t2);
                long long cnt = 0;
                if (lb <= ub) cnt = ub - lb + 1;
                dp_cur[0][x][y] = mint(cnt);
            }
        }
        // transitions for i=2..N
        for (int i = 2; i <= N; i++) {
            // clear next
            for (int mx = 0; mx < 3; mx++)
                for (int x = 0; x < 2; x++)
                    for (int y = 0; y < 2; y++)
                        dp_nxt[mx][x][y] = 0;
            long long t_prev = A[i-1], t_cur = A[i], t_next = A[i+1];
            // precompute counts for B_i categories
            mint cat[2][2][2];
            for (int b1 = 0; b1 < 2; b1++) for (int b2 = 0; b2 < 2; b2++) for (int b3 = 0; b3 < 2; b3++) {
                long long lb = 1, ub = M;
                if (b1) lb = max(lb, t_prev + 1);
                else ub = min(ub, t_prev);
                if (b2) lb = max(lb, t_cur + 1);
                else ub = min(ub, t_cur);
                if (b3) lb = max(lb, t_next + 1);
                else ub = min(ub, t_next);
                long long c = 0;
                if (lb <= ub) c = ub - lb + 1;
                cat[b1][b2][b3] = mint(c);
            }
            // DP step
            for (int mx = 0; mx < 3; mx++) {
                for (int p1 = 0; p1 < 2; p1++) for (int p2 = 0; p2 < 2; p2++) {
                    mint ways = dp_cur[mx][p1][p2];
                    if (ways.val() == 0) continue;
                    for (int b1 = 0; b1 < 2; b1++) for (int b2 = 0; b2 < 2; b2++) for (int b3 = 0; b3 < 2; b3++) {
                        mint c = cat[b1][b2][b3];
                        if (c.val() == 0) continue;
                        int gain = b1 + p2 - (p1 + b2);
                        int gcl = gain > 0 ? gain : 0;
                        int nmx = max(mx, gcl);
                        if (nmx > 2) nmx = 2;
                        dp_nxt[nmx][b2][b3] += ways * c;
                    }
                }
            }
            // swap
            for (int mx = 0; mx < 3; mx++)
                for (int x = 0; x < 2; x++) for (int y = 0; y < 2; y++)
                    dp_cur[mx][x][y] = dp_nxt[mx][x][y];
        }
        // accumulate counts
        for (int x = 0; x < 2; x++) for (int y = 0; y < 2; y++) {
            cnt1 += dp_cur[1][x][y];
            cnt2 += dp_cur[2][x][y];
        }
    }
    // total extra from best swap
    mint extra = cnt1 + cnt2 * 2;
    mint ans = base_sum + extra;
    cout << ans.val() << "\n";
    return 0;
}
0