結果

問題 No.3119 A Little Cheat
ユーザー AKI
提出日時 2025-04-20 13:19:01
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 41 ms / 2,000 ms
コード長 2,539 bytes
コンパイル時間 2,404 ms
コンパイル使用メモリ 199,272 KB
実行使用メモリ 10,240 KB
最終ジャッジ日時 2025-04-20 13:19:07
合計ジャッジ時間 6,308 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 49
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
const long long MOD = 998244353;

/* a^e (mod MOD) */
long long mod_pow(long long a,long long e){
    long long r=1%MOD, x=a%MOD;
    for(;e;e>>=1,x=(__int128)x*x%MOD) if(e&1) r=(__int128)r*x%MOD;
    return r;
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N; long long M;
    if(!(cin>>N>>M)) return 0;
    vector<long long> A(N);
    for(auto &v:A) cin>>v;

    /* --- BASE 部 --- */
    long long sum=0;
    for(auto a:A){ sum += (M-a)%MOD; if(sum>=MOD) sum-=MOD; }
    long long powM_N1 = mod_pow(M, N-1);
    long long BASE = (__int128)powM_N1*sum % MOD;

    if(N==1){           // スワップが無い
        cout<<BASE<<"\n";
        return 0;
    }

    /* 区間情報 */
    vector<long long> lo(N-1), hi(N-1), len(N-1);
    vector<int> sig(N-1);       // +1,0,-1
    for(int i=0;i<N-1;i++){
        lo[i]=min(A[i],A[i+1]);
        hi[i]=max(A[i],A[i+1]);
        len[i]=hi[i]-lo[i];
        sig[i]=(A[i]<A[i+1])?1:(A[i]>A[i+1]?-1:0);
    }

    long long Mmod = M%MOD;
    long long Sall = Mmod;          // B1 … 何でも OK
    long long Sin  = len[0]%MOD;    // B1 が I0 の “中”

    /* 1 本ずつ処理 */
    for(int i=0;i<N-1;i++){
        long long lcur=len[i];
        int s=sig[i];

        long long lnext   = (i==N-2)?0:len[i+1];
        long long inter   = 0;
        if(i!=N-2){
            long long x=max(lo[i],lo[i+1]);
            long long y=min(hi[i],hi[i+1]);
            if(y>x) inter=y-x;      // (x, y]
        }

        long long Sall_n, Sin_n=0;
        if(s==1){           /* 上り */
            Sall_n = ( (__int128)lcur*Sin + (__int128)(M-lcur)*Sall ) % MOD;
            if(i!=N-2)
                Sin_n = ( (__int128)inter*Sin + (__int128)(lnext-inter)*Sall )%MOD;
        }else if(s==-1){    /* 下り */
            long long Sout = (Sall - Sin + MOD)%MOD;
            Sall_n = ( (__int128)lcur*Sall + (__int128)(M-lcur)*Sout ) % MOD;
            if(i!=N-2)
                Sin_n = ( (__int128)inter*Sall + (__int128)(lnext-inter)*Sout )%MOD;
        }else{              /* 同値 */
            Sall_n = (__int128)Mmod*Sall % MOD;
            if(i!=N-2)  Sin_n = (__int128)lnext*Sall % MOD;
        }
        Sall = Sall_n;  Sin = Sin_n;        // advance
    }
    long long T = Sall;                 // 禁止パターン無しの列数
    long long powM_N = (__int128)powM_N1*Mmod % MOD;
    long long extra  = (powM_N - T + MOD) % MOD;

    cout << (BASE + extra) % MOD << '\n';
    return 0;
}
0