結果
問題 |
No.3119 A Little Cheat
|
ユーザー |
|
提出日時 | 2025-04-15 17:20:01 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 474 ms / 2,000 ms |
コード長 | 2,717 bytes |
コンパイル時間 | 367 ms |
コンパイル使用メモリ | 82,656 KB |
実行使用メモリ | 108,852 KB |
最終ジャッジ日時 | 2025-04-15 17:20:21 |
合計ジャッジ時間 | 18,681 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 52 |
ソースコード
N, M = map(int, input().split()) A = list(map(int, input().split())) powM = [1]*N for i in range(N-1): powM[i+1] = powM[i]*M % 998244353 ans = 0 dp = [0]*4 dp[0] = min(A[0], A[1]) dp[1] = max(A[0], A[1])-min(A[0], A[1]) dp[2] = M-max(A[0], A[1]) ans = (ans+(M-A[0])*powM[N-1]) % 998244353 A.append(0) for i in range(1, N): m1 = [0, min(A[i-1], A[i]), max(A[i-1], A[i]), M] m2 = [0, min(A[i], A[i+1]), max(A[i], A[i+1]), M] ndp = [0]*4 if A[i-1] < A[i]: # swap : (0,1),(2,1) # +1 : (0,2),(1,2),(2,2) # none : (0,0),(1,0),(1,1),(2,0) for f in range(3): for s in range(3): if (f, s) == (0, 1) or (f, s) == (2, 1): c = m1[s+1]-m1[s] ndp[3] = (ndp[3]+dp[f]*c) % 998244353 ans = (ans+dp[f]*c % 998244353*powM[N-1-i]) % 998244353 elif s == 2: for t in range(3): c = max(0, min(m1[s+1], m2[t+1])-max(m1[s], m2[t])) ndp[t] = (ndp[t]+dp[f]*c) % 998244353 ans = (ans+dp[f]*c % 998244353*powM[N-1-i]) % 998244353 else: for t in range(3): c = max(0, min(m1[s+1], m2[t+1])-max(m1[s], m2[t])) ndp[t] = (ndp[t]+dp[f]*c) % 998244353 else: # swap+1 : (1,0) # swap+2 : (1,2) # +1 : (0,1),(0,2),(1,1),(2,1),(2,2) # none : (0,0),(2,0) for f in range(3): for s in range(3): if (f, s) == (1, 2): c = m1[s+1]-m1[s] ndp[3] = (ndp[3]+dp[f]*c) % 998244353 ans = (ans+dp[f]*c % 998244353*powM[N-1-i]*2) % 998244353 elif (f, s) == (1, 0): c = m1[s+1]-m1[s] ndp[3] = (ndp[3]+dp[f]*c) % 998244353 ans = (ans + dp[f]*c % 998244353*powM[N-1-i]) % 998244353 elif (f, s) == (0, 1) or (f, s) == (0, 2) or (f, s) == (1, 1) or (f, s) == (2, 1) or (f, s) == (2, 2): for t in range(3): c = max(0, min(m1[s+1], m2[t+1])-max(m1[s], m2[t])) ndp[t] = (ndp[t]+dp[f]*c) % 998244353 ans = (ans+dp[f]*c % 998244353*powM[N-1-i]) % 998244353 else: for t in range(3): c = max(0, min(m1[s+1], m2[t+1])-max(m1[s], m2[t])) ndp[t] = (ndp[t]+dp[f]*c) % 998244353 ans = (ans+dp[3]*(M-A[i]) % 998244353*powM[N-1-i]) % 998244353 ndp[3] = (ndp[3]+dp[3]*M) % 998244353 for i in range(4): dp[i] = ndp[i] print(ans)