結果
問題 |
No.3119 A Little Cheat
|
ユーザー |
|
提出日時 | 2025-04-24 15:07:03 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 |
WA
|
実行時間 | - |
コード長 | 3,473 bytes |
コンパイル時間 | 2,506 ms |
コンパイル使用メモリ | 202,876 KB |
実行使用メモリ | 14,356 KB |
最終ジャッジ日時 | 2025-04-24 15:07:10 |
合計ジャッジ時間 | 6,312 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | WA * 3 |
other | AC * 3 WA * 46 |
ソースコード
#include <bits/stdc++.h> using namespace std; const int MOD = 998244353; using ll = long long; // a^e mod MOD ll modpow(ll a, ll e) { ll r = 1; while (e) { if (e & 1) r = r * a % MOD; a = a * a % MOD; e >>= 1; } return r; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); int N; ll M; cin >> N >> M; vector<ll> A(N + 2); for (int i = 1; i <= N; i++) cin >> A[i]; // M^N and M^(N-1) ll Mn = modpow(M, N); ll Mn1 = Mn * modpow(M, MOD - 2) % MOD; // M^{N-1} // sum S0 = M^{N-1} * sum_i(M - A[i]) ll sumMA = 0; for (int i = 1; i <= N; i++) { sumMA = (sumMA + (M - A[i]) % MOD) % MOD; } ll sumS0 = Mn1 * sumMA % MOD; // Z1: 2-state DP vector<ll> dp1(2), dp2(2); dp1[0] = A[1] % MOD; dp1[1] = (M - A[1]) % MOD; for (int i = 1; i < N; i++) { dp2[0] = dp2[1] = 0; bool inc = (A[i] < A[i + 1]); ll c0 = A[i + 1] % MOD; ll c1 = (M - A[i + 1]) % MOD; for (int s = 0; s < 2; s++) { for (int t = 0; t < 2; t++) { if (s == 0 && t == 1 && inc) continue; ll ways = (t == 0 ? c0 : c1); dp2[t] = (dp2[t] + dp1[s] * ways) % MOD; } } dp1.swap(dp2); } ll Z1 = (dp1[0] + dp1[1]) % MOD; // Z0: 3-state DP static int delta_lt[3][3] = { { 0, 1, 2}, {-1, 0, 1}, { 0, 1, 0} }; static int delta_gt[3][3] = { { 0,-1, 0}, { 1, 0, 1}, { 0,-1, 0} }; // ranges[i][k] = {L, R} for k=0,1,2 relative to (A[i], A[i+1]) vector< array<pair<ll,ll>,3> > ranges(N + 1); for (int i = 1; i < N; i++) { ll a = A[i], b = A[i + 1]; ll U = min(a, b), V = max(a, b); ranges[i][0] = {1, U}; ranges[i][1] = {U + 1, V}; ranges[i][2] = {V + 1, M}; } vector<ll> dp0(3), dp3(3); // 初期値: i=1 for (int k = 0; k < 3; k++) { auto [L, R] = ranges[1][k]; dp0[k] = max(0LL, R - L + 1) % MOD; } // i = 1..N-2 for (int i = 1; i < N - 1; i++) { dp3 = {0, 0, 0}; int (*delt)[3] = (A[i] < A[i + 1]) ? delta_lt : delta_gt; for (int k = 0; k < 3; k++) { if (dp0[k] == 0) continue; for (int l = 0; l < 3; l++) { if (delt[k][l] > 0) continue; auto [L1, R1] = ranges[i][l]; if (R1 < L1) continue; for (int m = 0; m < 3; m++) { auto [L2, R2] = ranges[i + 1][m]; ll L = max(L1, L2), R = min(R1, R2); if (R < L) continue; ll cnt = (R - L + 1) % MOD; dp3[m] = (dp3[m] + dp0[k] * cnt) % MOD; } } } dp0.swap(dp3); } // 最後: i=N-1 -> N ll Z0 = 0; { int i = N - 1; int (*delt)[3] = (A[i] < A[i + 1]) ? delta_lt : delta_gt; for (int k = 0; k < 3; k++) { if (dp0[k] == 0) continue; for (int l = 0; l < 3; l++) { if (delt[k][l] > 0) continue; auto [L, R] = ranges[i][l]; if (R < L) continue; ll cnt = (R - L + 1) % MOD; Z0 = (Z0 + dp0[k] * cnt) % MOD; } } } ll ans = (sumS0 + 2 * Mn - Z0 - Z1) % MOD; if (ans < 0) ans += MOD; cout << ans << "\n"; return 0; }