結果

| 問題 | No.3119 A Little Cheat |
|---|---|
| コンテスト | |
| ユーザー | |
| 提出日時 | 2025-04-24 15:07:03 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | WA |
| 実行時間 | - |
| コード長 | 3,473 bytes |
| コンパイル時間 | 2,506 ms |
| コンパイル使用メモリ | 202,876 KB |
| 実行使用メモリ | 14,356 KB |
| 最終ジャッジ日時 | 2025-04-24 15:07:10 |
| 合計ジャッジ時間 | 6,312 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge3 |

(要ログイン)

| ファイルパターン | 結果 |
|---|---|
| sample | WA * 3 |
| other | AC * 3 WA * 46 |

ソースコード
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
using ll = long long;

// Computes a^e mod MOD by binary exponentiation.
// a may be any long long (negative or >= MOD): it is first reduced into
// [0, MOD) so every intermediate product stays below MOD^2 < 2^63.
// e must be non-negative; the result for e == 0 is 1.
ll modpow(ll a, ll e) {
    a %= MOD;
    if (a < 0) a += MOD;  // C++ % keeps the sign of the dividend
    ll r = 1;
    for (; e > 0; e >>= 1) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
    }
    return r;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int N;
ll M;
cin >> N >> M;
vector<ll> A(N + 2);
for (int i = 1; i <= N; i++) cin >> A[i];
// M^N and M^(N-1)
ll Mn = modpow(M, N);
ll Mn1 = Mn * modpow(M, MOD - 2) % MOD; // M^{N-1}
// sum S0 = M^{N-1} * sum_i(M - A[i])
ll sumMA = 0;
for (int i = 1; i <= N; i++) {
sumMA = (sumMA + (M - A[i]) % MOD) % MOD;
}
ll sumS0 = Mn1 * sumMA % MOD;
// Z1: 2-state DP
vector<ll> dp1(2), dp2(2);
dp1[0] = A[1] % MOD;
dp1[1] = (M - A[1]) % MOD;
for (int i = 1; i < N; i++) {
dp2[0] = dp2[1] = 0;
bool inc = (A[i] < A[i + 1]);
ll c0 = A[i + 1] % MOD;
ll c1 = (M - A[i + 1]) % MOD;
for (int s = 0; s < 2; s++) {
for (int t = 0; t < 2; t++) {
if (s == 0 && t == 1 && inc) continue;
ll ways = (t == 0 ? c0 : c1);
dp2[t] = (dp2[t] + dp1[s] * ways) % MOD;
}
}
dp1.swap(dp2);
}
ll Z1 = (dp1[0] + dp1[1]) % MOD;
// Z0: 3-state DP
static int delta_lt[3][3] = {
{ 0, 1, 2},
{-1, 0, 1},
{ 0, 1, 0}
};
static int delta_gt[3][3] = {
{ 0,-1, 0},
{ 1, 0, 1},
{ 0,-1, 0}
};
// ranges[i][k] = {L, R} for k=0,1,2 relative to (A[i], A[i+1])
vector< array<pair<ll,ll>,3> > ranges(N + 1);
for (int i = 1; i < N; i++) {
ll a = A[i], b = A[i + 1];
ll U = min(a, b), V = max(a, b);
ranges[i][0] = {1, U};
ranges[i][1] = {U + 1, V};
ranges[i][2] = {V + 1, M};
}
vector<ll> dp0(3), dp3(3);
// 初期値: i=1
for (int k = 0; k < 3; k++) {
auto [L, R] = ranges[1][k];
dp0[k] = max(0LL, R - L + 1) % MOD;
}
// i = 1..N-2
for (int i = 1; i < N - 1; i++) {
dp3 = {0, 0, 0};
int (*delt)[3] = (A[i] < A[i + 1]) ? delta_lt : delta_gt;
for (int k = 0; k < 3; k++) {
if (dp0[k] == 0) continue;
for (int l = 0; l < 3; l++) {
if (delt[k][l] > 0) continue;
auto [L1, R1] = ranges[i][l];
if (R1 < L1) continue;
for (int m = 0; m < 3; m++) {
auto [L2, R2] = ranges[i + 1][m];
ll L = max(L1, L2), R = min(R1, R2);
if (R < L) continue;
ll cnt = (R - L + 1) % MOD;
dp3[m] = (dp3[m] + dp0[k] * cnt) % MOD;
}
}
}
dp0.swap(dp3);
}
// 最後: i=N-1 -> N
ll Z0 = 0;
{
int i = N - 1;
int (*delt)[3] = (A[i] < A[i + 1]) ? delta_lt : delta_gt;
for (int k = 0; k < 3; k++) {
if (dp0[k] == 0) continue;
for (int l = 0; l < 3; l++) {
if (delt[k][l] > 0) continue;
auto [L, R] = ranges[i][l];
if (R < L) continue;
ll cnt = (R - L + 1) % MOD;
Z0 = (Z0 + dp0[k] * cnt) % MOD;
}
}
}
ll ans = (sumS0 + 2 * Mn - Z0 - Z1) % MOD;
if (ans < 0) ans += MOD;
cout << ans << "\n";
return 0;
}