結果

問題 No.3119 A Little Cheat
ユーザー shibh308
提出日時 2025-04-20 18:59:25
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 73 ms / 2,000 ms
コード長 4,817 bytes
コンパイル時間 2,074 ms
コンパイル使用メモリ 206,356 KB
実行使用メモリ 7,844 KB
最終ジャッジ日時 2025-04-20 18:59:33
合計ジャッジ時間 6,634 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 49
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
using ll = long long;
static const ll MOD = 998244353;

// Modular addition: returns a + b with at most one subtraction of MOD applied.
// The result lies in [0, MOD) only when both arguments already lie in that
// range; no general reduction is performed.
inline ll mod_add(ll a, ll b) {
    const ll sum = a + b;
    return (sum < MOD) ? sum : sum - MOD;
}

// Modular multiplication: returns (a * b) mod MOD as a canonical residue in
// [0, MOD), for any signed 64-bit operands.
inline ll mod_mul(ll a, ll b) {
    // Reduce each operand first: both remainders are below MOD (< 2^30) in
    // magnitude, so their product cannot overflow a 64-bit integer.
    ll result = (a % MOD) * (b % MOD) % MOD;
    // The C++ remainder keeps the sign of the dividend, so a negative operand
    // would otherwise leak a negative value to the caller.
    if (result < 0) {
        result += MOD;
    }
    return result;
}

// Number of integers contained in both closed intervals [a1, b1] and [a2, b2].
// Returns 0 when the intervals do not intersect (which includes the case where
// either interval is empty, i.e. its lower bound exceeds its upper bound).
inline ll overlap(int a1, int b1, int a2, int b2) {
    const int first = (a1 < a2) ? a2 : a1;  // larger of the two lower bounds
    const int last = (b2 < b1) ? b2 : b1;   // smaller of the two upper bounds
    if (last < first) {
        return 0;
    }
    return last - first + 1;
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N, M;
    cin >> N >> M;
    vector<int> A(N);
    for(int i = 0; i < N; i++){
        cin >> A[i];
    }

    // Precompute powers of M mod MOD
    vector<ll> powM(N+1, 1);
    ll M_mod = M % MOD;
    for(int i = 1; i <= N; i++){
        powM[i] = mod_mul(powM[i-1], M_mod);
    }

    // 1) Base score: sum (M - A[i]) * M^(N-1)
    ll base = 0;
    for(int i = 0; i < N; i++){
        base = mod_add(base,
                       mod_mul((M - A[i] + MOD) % MOD,
                               powM[N-1]));
    }
    if(N == 1){
        cout << base << "\n";
        return 0;
    }

    // DP states: counts of prefixes ending in small/mid/large region
    int t0 = min(A[0], A[1]);
    int T0 = max(A[0], A[1]);
    ll dp_sm = t0;
    ll dp_md = T0 - t0;
    ll dp_lg = M - T0;

    ll X = 0;
    for(int i = 0; i < N - 1; i++){
        // current thresholds
        int ti = min(A[i], A[i+1]);
        int Ti = max(A[i], A[i+1]);
        int sm_i = ti;
        int md_i = Ti - ti;
        int lg_i = M - Ti;

        // number of free tail choices
        ll tail = (i + 2 <= N) ? powM[N - i - 2] : 1LL;

        // add contribution for first improvement at i
        if(A[i] < A[i+1]){
            // improve by moving into mid_i
            ll ways = mod_add(dp_sm, dp_lg);
            X = mod_add(X, mod_mul(mod_mul(ways, md_i), tail));
        } else {
            // improve by moving into sm_i or lg_i
            ll ways = dp_md % MOD;
            ll sum = (sm_i + lg_i) % MOD;
            X = mod_add(X, mod_mul(mod_mul(ways, sum), tail));
        }

        // no need to update dp after last pair
        if(i == N - 2) break;

        // thresholds for next position
        int tn = min(A[i+1], A[i+2]);
        int Tn = max(A[i+1], A[i+2]);
        int sm_n = tn;
        int md_n = Tn - tn;
        int lg_n = M - Tn;

        // build allowed ranges for B_{i+1} to avoid early improvement
        // for each previous region "sm","md","lg"
        vector<pair<int,int>> allow_sm, allow_md, allow_lg;
        if(A[i] < A[i+1]){
            // cannot choose B_{i+1} in mid_i when coming from sm or lg
            allow_sm = {{1, ti}, {Ti+1, M}};
            allow_lg = allow_sm;
            allow_md = {{1, M}};
        } else {
            // cannot choose B_{i+1} in sm_i/lg_i when coming from md
            allow_md = {{ti+1, Ti}};
            allow_sm = {{1, M}};
            allow_lg = allow_sm;
        }

        // next intervals classification
        vector<pair<int,int>> next_sm = {{1, tn}};
        vector<pair<int,int>> next_md = {{tn+1, Tn}};
        vector<pair<int,int>> next_lg = {{Tn+1, M}};

        ll new_sm = 0, new_md = 0, new_lg = 0;

        // transition from "sm" region
        for(auto &seg : allow_sm){
            for(auto &nx : next_sm){
                new_sm += dp_sm * overlap(seg.first, seg.second, nx.first, nx.second);
            }
            for(auto &nx : next_md){
                new_md += dp_sm * overlap(seg.first, seg.second, nx.first, nx.second);
            }
            for(auto &nx : next_lg){
                new_lg += dp_sm * overlap(seg.first, seg.second, nx.first, nx.second);
            }
        }
        // from "md" region
        for(auto &seg : allow_md){
            for(auto &nx : next_sm){
                new_sm += dp_md * overlap(seg.first, seg.second, nx.first, nx.second);
            }
            for(auto &nx : next_md){
                new_md += dp_md * overlap(seg.first, seg.second, nx.first, nx.second);
            }
            for(auto &nx : next_lg){
                new_lg += dp_md * overlap(seg.first, seg.second, nx.first, nx.second);
            }
        }
        // from "lg" region
        for(auto &seg : allow_lg){
            for(auto &nx : next_sm){
                new_sm += dp_lg * overlap(seg.first, seg.second, nx.first, nx.second);
            }
            for(auto &nx : next_md){
                new_md += dp_lg * overlap(seg.first, seg.second, nx.first, nx.second);
            }
            for(auto &nx : next_lg){
                new_lg += dp_lg * overlap(seg.first, seg.second, nx.first, nx.second);
            }
        }

        dp_sm = new_sm % MOD;
        dp_md = new_md % MOD;
        dp_lg = new_lg % MOD;
    }

    ll answer = mod_add(base, X);
    cout << answer << "\n";
    return 0;
}
0