結果

問題 No.3119 A Little Cheat
ユーザー Yu_212
提出日時 2025-04-19 13:25:32
言語 C++23 (gcc 13.3.0 + boost 1.87.0)
結果 AC
実行時間 85 ms / 2,000 ms
コード長 6,018 bytes
コンパイル時間 3,366 ms
コンパイル使用メモリ 289,824 KB
実行使用メモリ 7,844 KB
最終ジャッジ日時 2025-04-19 13:25:41
合計ジャッジ時間 7,697 ms
ジャッジサーバーID (参考情報) judge5 / judge4
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 49
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
using namespace std;
using ll = long long;

// Large integer constants; presumably "infinity" sentinels (not referenced in the code shown).
constexpr int iinf = 1'000'000'000;
constexpr ll inf = 1'000'000'000'000'000'000LL;
// Modular integer modulo the compile-time constant `mod`.
// Invariant: v is kept in [0, mod). Division and inv() use Fermat's little
// theorem (x^(mod-2)), so they are only correct when mod is prime. Products
// v*m.v are formed in long long, so (mod-1)^2 must fit in long long.
template<long long mod>
struct Mint {
    using M=Mint; long long v;
    // Sets v from x with at most one subtraction; requires 0 <= x < 2*mod.
    M& put(long long x) { v=(x<mod)?x:x-mod; return *this; }
    // Accepts any x, including negatives: x%mod lies in (-mod, mod), so
    // x%mod+mod lies in [0, 2*mod) as put() requires.
    Mint(long long x=0) { put(x%mod+mod); }

    // Binary operators are const so they also work on const Mint values.
    M operator+(M m) const { return M().put(v+m.v); }
    M operator-(M m) const { return M().put(v+mod-m.v); }
    M operator*(M m) const { return M().put(v*m.v%mod); }
    M operator/(M m) const { return M().put(v*m.inv().v%mod); }

    // Compound assignments return *this by reference (the canonical form), so
    // chained uses such as (a += b) *= c update a itself, not a temporary copy.
    M& operator+=(M m) { return put(v+m.v); }
    M& operator-=(M m) { return put(v+mod-m.v); }
    M& operator*=(M m) { return put(v*m.v%mod); }
    M& operator/=(M m) { return put(v*m.inv().v%mod); }

    bool operator==(M m) const { return v==m.v; }
    bool operator!=(M m) const { return v!=m.v; }

    // x^m by binary exponentiation; requires m >= 0.
    M pow(long long m) const {
        M x=v, res=1;
        while (m) {
            if (m&1) res=res*x;
            x=x*x; m>>=1;
        }
        return res;
    }
    // Multiplicative inverse; requires prime mod and v != 0.
    M inv() const { return pow(mod-2); }
};
template<ll mod>
ostream&operator<<(ostream&o,Mint<mod>v){return o<<v.v;}


// Writes the elements of v separated by single spaces (no trailing space, no
// newline). Nested vectors print flattened with the same separator.
// v is taken by const reference so printing does not copy the container.
template<typename T>
ostream& operator<<(ostream &o, const vector<T>& v) {
    for (size_t i = 0; i < v.size(); i++)
        o << v[i] << (i+1<v.size()?" ":"");
    return o;
}
// Bottom-up (iterative) segment tree over a monoid (f, ti): f must be
// associative and ti its identity element. Leaves are stored at
// dat[n .. 2n), internal node k combines dat[2k] and dat[2k+1], root is dat[1].
template <typename T>
struct SegTree {
    using F = function<T(T, T)>;
    int n = 0;      // leaf capacity (a power of two); 0 for a default-constructed tree
    F f;
    T ti{};
    vector<T> dat;
    SegTree() {}
    // Tree with room for `num` leaves, every node initialized to ti.
    SegTree(F f, T ti,int num) : f(f), ti(ti) {
        // Smallest power of two >= num (1 when num <= 1). A plain loop instead
        // of the libstdc++-internal __bit_ceil, which is not a public std API.
        n = 1;
        while (n < num) n <<= 1;
        dat.assign(n << 1, ti);
    }
    // Tree whose leaves start as the elements of v; internal nodes are built
    // bottom-up in O(n). Accepts const vectors and temporaries.
    SegTree(F f,T ti,const vector<T>&v):SegTree(f,ti,static_cast<int>(v.size())){
        for (size_t i = 0; i < v.size(); i++)
            dat[n + i] = v[i];
        for(int i=n-1;i;i--) dat[i]=f(dat[i*2], dat[i*2+1]);
    }
    // Point assignment: leaf k := x, then recompute its ancestors. O(log n).
    void set_val(int k, T x) {
        dat[k += n] = x;
        while(k >>= 1) dat[k] = f(dat[k*2], dat[k*2+1]);
    }
    // Fold of f over leaves in the half-open range [a, b), in left-to-right
    // order; returns ti when the range is empty. O(log n).
    T query(int a, int b) const {
        if (a >= b) return ti;
        T vl = ti, vr = ti;
        for (int l=a+n, r=b+n; l<r; l>>=1, r>>=1) {
            if (l & 1) vl = f(vl, dat[l++]);
            if (r & 1) vr = f(dat[--r], vr);
        }
        return f(vl, vr);
    }
};

// Prime modulus used for all modular arithmetic; mint is Mint specialized to it.
constexpr int MOD = 998'244'353;
using mint = Mint<MOD>;

// Returns a^e mod MOD by binary exponentiation, as a canonical value in [0, MOD).
// a may be any ll, including negative values. e must be >= 0: for e < 0 the
// loop never ends, because `e >>= 1` keeps -1 at -1 (arithmetic shift).
ll modpow(ll a, ll e){
    ll r = 1;
    a %= MOD;
    // C++ % truncates toward zero, so a negative base leaves a in (-MOD, 0);
    // shift it into [0, MOD) so every intermediate product stays canonical.
    if(a < 0) a += MOD;
    while(e){
        if(e&1) r = r * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return r;
}

// Reads N, M and A[1..N] from stdin and prints (term1 + M^N - P) mod MOD, where
//   * term1 = M^{N-1} * sum_{i=1..N} (M - A[i]). Assuming 1 <= A[i] <= M, this
//     counts the pairs (B, i) with B in [1, M]^N and B[i] > A[i].
//   * P = number (mod MOD) of sequences B in [1, M]^N that contain no
//     "forbidden" adjacent pair (B[i-1], B[i]); see the transition loop.
// A case analysis of the forbidden conditions shows that a pair is forbidden
// exactly when swapping B[i-1] and B[i] increases the number of positions j
// with B[j] > A[j]. The increase is then exactly one; a single swap can never
// gain two. So M^N - P counts the sequences that some single adjacent swap
// improves.
// NOTE(review): that this total is what the task asks for is inferred from
// the code, not stated in it — confirm against the problem statement.
// Assumes N >= 1. For N == 0, modpow(M, N-1) would get exponent -1, and its
// `e >>= 1` loop never reaches 0.
int main() {
    cin.tie(0)->sync_with_stdio(false);

    int N;
    ll M;
    cin >> N >> M;
    vector<ll> A(N+2);
    for(int i = 1; i <= N; i++) cin >> A[i];
    // Sentinels: A[0] = 0 and A[N+1] = M+1 let build3(A[i-1], A[i]) and
    // build3(A[i], A[i+1]) be called uniformly at i = 1 and i = N.
    A[0] = 0;
    A[N+1] = M+1;

    // 1) Term1 = M^{N-1} * sum_{i=1}^N (M - A[i])
    ll sumMA = 0;
    for(int i = 1; i <= N; i++){
        sumMA = (sumMA + (M - A[i]) % MOD) % MOD;
    }
    ll powM_N   = modpow(M, N);
    ll powM_Nm1 = modpow(M, N-1);
    ll term1 = powM_Nm1 * sumMA % MOD;

    // 2) DP over positions. An entry ((p, n), cnt) at position i means:
    //    cnt = number (mod MOD) of prefixes B[1..i] with no forbidden pair,
    //    where B[i] lies in category p of build3(A[i-1], A[i]) and in
    //    category n of build3(A[i], A[i+1]).
    using State = pair<int,int>;
    vector<pair<State,ll>> dp_prev, dp_cur;

    // With lo = min(L,R) and hi = max(L,R), returns three value ranges
    // (categories):
    //   0: [1, lo],  1: [lo+1, hi],  2: [hi+1, M].
    // An empty range is normalized to {1, 0}, which intersects nothing.
    // With the sentinel hi = M+1, range 1 reaches M+1. Counts stay within
    // [1, M] because every state intersects it with a build3(A[i-1], A[i])
    // range (assuming A[1..N] <= M).
    auto build3 = [&](ll L, ll R){
        vector<pair<ll,ll>> parts(3);
        parts[0] = {1, min(L, R)};
        parts[1] = {min(L, R)+1, max(L, R)};
        parts[2] = {max(L, R)+1, M};
        for(auto &p: parts) if(p.first > p.second) p = {1,0};
        return parts;
    };

    // Position 1: no adjacent pair exists yet, so every value of B[1] is
    // allowed. Count how many values fall in each (p, n) combination.
    {
        auto pre = build3(A[0], A[1]);
        auto nxt = build3(A[1], A[2]);
        dp_prev.clear();
        for(int p = 0; p < 3; p++){
            auto &P = pre[p];
            if(P.first > P.second) continue;
            for(int n = 0; n < 3; n++){
                auto &Q = nxt[n];
                ll l = max(P.first, Q.first);
                ll r = min(P.second, Q.second);
                if(l <= r){
                    dp_prev.emplace_back(State(p,n), (r - l + 1) % MOD);
                }
            }
        }
    }

    // Positions 2..N: extend each prefix by B[i], skipping forbidden pairs.
    for(int i = 2; i <= N; i++){
        auto pre = build3(A[i-1], A[i]);
        auto nxt = build3(A[i],   A[i+1]);
        // forward: A[i-1] < A[i], so category 1 of `pre` is (A[i-1], A[i]].
        // Otherwise category 1 is (A[i], A[i-1]], which is empty when equal.
        bool forward = (A[i] > A[i-1]);

        // Candidate states (p, n) for B[i], each with its number of values w.
        vector<tuple<int,int,ll>> states;
        for(int p = 0; p < 3; p++){
            auto &P = pre[p]; if(P.first> P.second) continue;
            for(int n = 0; n < 3; n++){
                auto &Q = nxt[n]; if(Q.first>Q.second) continue;
                ll l = max(P.first, Q.first);
                ll r = min(P.second, Q.second);
                if(l <= r){
                    states.emplace_back(p, n, (r - l + 1) % MOD);
                }
            }
        }
        // dp_cur[idx] is aligned index-for-index with states[idx].
        dp_cur.clear();
        for(auto &st: states){ dp_cur.emplace_back(State(get<0>(st), get<1>(st)), 0LL); }

        // transitions
        for(auto &prv: dp_prev){
            int prev_rp = prv.first.first;   // category for B[i-1] wrt (A[i-2],A[i-1]) - unused
            int prev_rn = prv.first.second;  // category for B[i-1] wrt (A[i-1],A[i])
            ll cnt = prv.second;
            if(cnt == 0) continue;
            for(size_t idx = 0; idx < states.size(); idx++){
                int rp_cur = get<0>(states[idx]);  // B[i] category wrt (A[i-1],A[i])
                int rn_cur = get<1>(states[idx]); // B[i] category wrt (A[i],A[i+1])
                ll w = get<2>(states[idx]);
                bool bad = false;
                if(forward){
                    // B[i] in (A[i-1], A[i]] while B[i-1] is outside it (cat0 or
                    // cat2): swapping this pair gains one position, so exclude.
                    if(rp_cur==1 && (prev_rn==0 || prev_rn==2)) bad = true;
                } else {
                    // B[i-1] in (A[i], A[i-1]] while B[i] is outside it (cat0 or
                    // cat2): swapping this pair gains one position, so exclude.
                    if(prev_rn==1 && (rp_cur==0 || rp_cur==2)) bad = true;
                }
                if(bad) continue;
                dp_cur[idx].second = (dp_cur[idx].second + cnt * w) % MOD;
            }
        }
        // Keep only states with a nonzero residue; zero entries add nothing later.
        dp_prev.clear();
        for(auto &c: dp_cur) if(c.second) dp_prev.push_back(c);
    }

    // P = number (mod MOD) of sequences B with no forbidden adjacent pair
    ll P = 0;
    for(auto &pr: dp_prev) P = (P + pr.second) % MOD;

    // final answer: term1 + (M^N - P), with +MOD keeping the difference non-negative
    ll ans = (term1 + (powM_N - P + MOD) % MOD) % MOD;
    cout << ans << "\n";
    return 0;
}
0