結果

問題 No.3119 A Little Cheat
ユーザー Naru820
提出日時 2025-08-11 13:39:15
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 63 ms / 2,000 ms
コード長 3,927 bytes
コンパイル時間 2,791 ms
コンパイル使用メモリ 284,296 KB
実行使用メモリ 12,688 KB
最終ジャッジ日時 2025-08-11 13:39:25
合計ジャッジ時間 9,254 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 49
権限があれば一括ダウンロードができます

ソースコード

diff #

// by GPT-5
#include <bits/stdc++.h>
using namespace std;
// 64-bit signed integer alias used for all modular quantities below.
using int64 = long long;
// Modulus applied to every accumulated count and to the final answer.
constexpr int64 MOD = 998244353;

// Square-and-multiply exponentiation: returns a^e reduced modulo MOD.
// A non-positive exponent leaves the accumulator at its initial value 1 % MOD.
// Products are formed in __int128 so the multiplication cannot overflow int64.
int64 modpow(int64 a, long long e) {
    int64 base = a % MOD;
    int64 acc = 1 % MOD;
    for (; e > 0; e >>= 1) {
        if (e & 1) {
            acc = static_cast<int64>(static_cast<__int128>(acc) * base % MOD);
        }
        base = static_cast<int64>(static_cast<__int128>(base) * base % MOD);
    }
    return acc;
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int N;
    long long M;
    if(!(cin >> N >> M)) return 0;
    vector<long long> A(N);
    for (int i = 0; i < N; ++i) cin >> A[i];

    // S0 = M^{N-1} * sum_i (M - A_i)  (mod MOD)
    int64 powM_N1 = modpow(M, N - 1);
    long long sumMminusA = 0;
    for (int i = 0; i < N; ++i) sumMminusA += (M - A[i]);
    sumMminusA %= MOD;
    int64 S0 = (__int128)powM_N1 * (sumMminusA % MOD) % MOD;

    // s_j = size of I_j = |A_{j+1} - A_j|
    vector<long long> s(max(0, N-1));
    for (int j = 0; j < N-1; ++j) s[j] = llabs(A[j+1] - A[j]);

    // cnt[i][state], state index = l*2 + r  (l: in I_{i-1}, r: in I_i)
    vector<array<long long,4>> cnt(N);
    for (int i = 0; i < N; ++i) {
        long long s_left = (i-1 >= 0 ? s[i-1] : 0);
        long long s_right = (i < N-1 ? s[i] : 0);
        long long t = 0; // overlap size L 竏ゥ R
        if (i-1 >= 0 && i < N-1) {
            long long L1 = min(A[i-1], A[i]) + 1;
            long long R1 = max(A[i-1], A[i]);
            long long L2 = min(A[i], A[i+1]) + 1;
            long long R2 = max(A[i], A[i+1]);
            long long L = max(L1, L2);
            long long R = min(R1, R2);
            if (R >= L) t = (R - L + 1);
            else t = 0;
        } else t = 0;
        long long both = t;
        long long only_left = s_left - t;
        long long only_right = s_right - t;
        long long none = M - (s_left + s_right - t);
        if (none < 0) none = 0; // safety (shouldn't happen)
        cnt[i][0] = none;        // l=0, r=0
        cnt[i][1] = only_right;  // l=0, r=1
        cnt[i][2] = only_left;   // l=1, r=0
        cnt[i][3] = both;        // l=1, r=1
    }

    // DP to count sequences with NO ホ・1 anywhere (i.e. no forbidden transitions)
    // dp_prev[state] = number of sequences up to position i ending in 'state'
    vector<int64> dp_prev(4), dp_cur(4);
    for (int st = 0; st < 4; ++st) dp_prev[st] = (cnt[0][st] % MOD);

    for (int i = 1; i < N; ++i) {
        fill(dp_cur.begin(), dp_cur.end(), 0);
        // forbidden pattern for pair (i-1): depends on A[i-1] vs A[i]
        // if A[i-1] < A[i]: forbid (prev.r == 0 && curr.l == 1)
        // if A[i-1] > A[i]: forbid (prev.r == 1 && curr.l == 0)
        int forbid_prev_r = -1, forbid_curr_l = -1;
        if (A[i-1] < A[i]) { forbid_prev_r = 0; forbid_curr_l = 1; }
        else if (A[i-1] > A[i]) { forbid_prev_r = 1; forbid_curr_l = 0; }
        // else equal -> no forbidden pattern

        for (int prev = 0; prev < 4; ++prev) {
            int l_prev = prev / 2;
            int r_prev = prev % 2;
            if (dp_prev[prev] == 0) continue;
            for (int cur = 0; cur < 4; ++cur) {
                int l_cur = cur / 2;
                int r_cur = cur % 2;
                if (forbid_prev_r != -1) {
                    if (r_prev == forbid_prev_r && l_cur == forbid_curr_l) continue; // forbidden
                }
                // add dp_prev[prev] * cnt[i][cur]
                int64 add = dp_prev[prev];
                int64 times = cnt[i][cur] % MOD;
                __int128 prod = (__int128)add * times;
                dp_cur[cur] = (dp_cur[cur] + (int64)(prod % MOD)) % MOD;
            }
        }
        dp_prev = dp_cur;
    }

    int64 count_no_forbid = 0;
    for (int st = 0; st < 4; ++st) count_no_forbid = (count_no_forbid + dp_prev[st]) % MOD;

    int64 total = modpow(M, N);
    int64 C1 = (total - count_no_forbid) % MOD;
    if (C1 < 0) C1 += MOD;

    int64 answer = (S0 + C1) % MOD;
    cout << answer << "\n";
    return 0;
}
0