結果

問題 No.3119 A Little Cheat
ユーザー keigo kuwata
提出日時 2025-04-24 15:07:03
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 3,473 bytes
コンパイル時間 2,506 ms
コンパイル使用メモリ 202,876 KB
実行使用メモリ 14,356 KB
最終ジャッジ日時 2025-04-24 15:07:10
合計ジャッジ時間 6,312 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample WA * 3
other AC * 3 WA * 46
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
using ll = long long;

// a^e mod MOD
ll modpow(ll a, ll e) {
    ll r = 1;
    while (e) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return r;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N;
    ll M;
    cin >> N >> M;
    vector<ll> A(N + 2);
    for (int i = 1; i <= N; i++) cin >> A[i];

    // M^N and M^(N-1)
    ll Mn = modpow(M, N);
    ll Mn1 = Mn * modpow(M, MOD - 2) % MOD; // M^{N-1}

    // sum S0 = M^{N-1} * sum_i(M - A[i])
    ll sumMA = 0;
    for (int i = 1; i <= N; i++) {
        sumMA = (sumMA + (M - A[i]) % MOD) % MOD;
    }
    ll sumS0 = Mn1 * sumMA % MOD;

    // Z1: 2-state DP
    vector<ll> dp1(2), dp2(2);
    dp1[0] = A[1] % MOD;
    dp1[1] = (M - A[1]) % MOD;
    for (int i = 1; i < N; i++) {
        dp2[0] = dp2[1] = 0;
        bool inc = (A[i] < A[i + 1]);
        ll c0 = A[i + 1] % MOD;
        ll c1 = (M - A[i + 1]) % MOD;
        for (int s = 0; s < 2; s++) {
            for (int t = 0; t < 2; t++) {
                if (s == 0 && t == 1 && inc) continue;
                ll ways = (t == 0 ? c0 : c1);
                dp2[t] = (dp2[t] + dp1[s] * ways) % MOD;
            }
        }
        dp1.swap(dp2);
    }
    ll Z1 = (dp1[0] + dp1[1]) % MOD;

    // Z0: 3-state DP
    static int delta_lt[3][3] = {
        { 0, 1, 2},
        {-1, 0, 1},
        { 0, 1, 0}
    };
    static int delta_gt[3][3] = {
        { 0,-1, 0},
        { 1, 0, 1},
        { 0,-1, 0}
    };
    // ranges[i][k] = {L, R} for k=0,1,2 relative to (A[i], A[i+1])
    vector< array<pair<ll,ll>,3> > ranges(N + 1);
    for (int i = 1; i < N; i++) {
        ll a = A[i], b = A[i + 1];
        ll U = min(a, b), V = max(a, b);
        ranges[i][0] = {1, U};
        ranges[i][1] = {U + 1, V};
        ranges[i][2] = {V + 1, M};
    }
    vector<ll> dp0(3), dp3(3);
    // 初期値: i=1
    for (int k = 0; k < 3; k++) {
        auto [L, R] = ranges[1][k];
        dp0[k] = max(0LL, R - L + 1) % MOD;
    }
    // i = 1..N-2
    for (int i = 1; i < N - 1; i++) {
        dp3 = {0, 0, 0};
        int (*delt)[3] = (A[i] < A[i + 1]) ? delta_lt : delta_gt;
        for (int k = 0; k < 3; k++) {
            if (dp0[k] == 0) continue;
            for (int l = 0; l < 3; l++) {
                if (delt[k][l] > 0) continue;
                auto [L1, R1] = ranges[i][l];
                if (R1 < L1) continue;
                for (int m = 0; m < 3; m++) {
                    auto [L2, R2] = ranges[i + 1][m];
                    ll L = max(L1, L2), R = min(R1, R2);
                    if (R < L) continue;
                    ll cnt = (R - L + 1) % MOD;
                    dp3[m] = (dp3[m] + dp0[k] * cnt) % MOD;
                }
            }
        }
        dp0.swap(dp3);
    }
    // 最後: i=N-1 -> N
    ll Z0 = 0;
    {
        int i = N - 1;
        int (*delt)[3] = (A[i] < A[i + 1]) ? delta_lt : delta_gt;
        for (int k = 0; k < 3; k++) {
            if (dp0[k] == 0) continue;
            for (int l = 0; l < 3; l++) {
                if (delt[k][l] > 0) continue;
                auto [L, R] = ranges[i][l];
                if (R < L) continue;
                ll cnt = (R - L + 1) % MOD;
                Z0 = (Z0 + dp0[k] * cnt) % MOD;
            }
        }
    }

    ll ans = (sumS0 + 2 * Mn - Z0 - Z1) % MOD;
    if (ans < 0) ans += MOD;
    cout << ans << "\n";
    return 0;
}
0