結果
| 問題 | No.3119 A Little Cheat |
|---|---|
| コンテスト | |
| ユーザー | |
| 提出日時 | 2025-04-19 01:41:46 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 715 ms / 2,000 ms |
| コード長 | 3,042 bytes |
| コンパイル時間 | 775 ms |
| コンパイル使用メモリ | 82,048 KB |
| 実行使用メモリ | 105,600 KB |
| 最終ジャッジ日時 | 2025-04-19 01:42:16 |
| 合計ジャッジ時間 | 27,796 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge1 |
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 3 | 
| other | AC * 49 | 
ソースコード
import sys
# Fast stdin helpers.
# rstrip("\r\n") instead of [:-1]: slicing blindly drops the last character,
# which corrupts the final line when the input has no trailing newline
# ("3 5" would parse as [3]) and leaves a stray "\r" on CRLF input.
input = lambda: sys.stdin.readline().rstrip("\r\n")
ni = lambda: int(input())
na = lambda: list(map(int, input().split()))
# Answer printers (not used by this solution).
yes = lambda: print("yes")
Yes = lambda: print("Yes")
YES = lambda: print("YES")
no = lambda: print("no")
No = lambda: print("No")
NO = lambda: print("NO")
#######################################################################
def f(a, b):
    """Return how many indices i in range(n) satisfy a[i] < b[i].

    n is the module-level length read from the input, not len(a); entries of
    a or b at index n or beyond are ignored.
    """
    return sum(a[i] < b[i] for i in range(n))
F = 0  # debug counter: how many b had an improving adjacent swap


def score(a, b):
    """Return the best f(a, b') where b' is b or b with one adjacent swap.

    Brute-force debug helper.  When some swap strictly beats the unswapped
    score, it prints b (as a list) and increments the global counter F.
    Swap range is governed by the module-level n, as in f.
    """
    global F
    cur = list(b)
    base = f(a, cur)  # score without any swap
    best = base
    for i in range(n - 1):
        cur[i], cur[i + 1] = cur[i + 1], cur[i]
        best = max(best, f(a, cur))
        cur[i], cur[i + 1] = cur[i + 1], cur[i]  # undo the trial swap
    if best > base:
        print(cur)
        F += 1
    return best
from itertools import product
def naive(n, m, a):
    """Brute force: sum score(a, b) over every b in {1..m}^n.

    Exponential in n; only meant for cross-checking on tiny inputs.
    """
    values = range(1, m + 1)
    return sum(score(a, b) for b in product(values, repeat=n))
def g(a0, a1, b0, b1):
    """Return True iff swapping b0 and b1 strictly increases the number of
    won positions for the pair (a0, a1), where a position is won when b > a.
    """
    # Canonicalise so that a0 <= a1, carrying each b along with its a.
    if a0 > a1:
        a0, a1, b0, b1 = a1, a0, b1, b0
    # Nothing won before; after the swap b1 wins the first slot.
    none_to_one = b0 <= a0 < b1 <= a1
    # Only the first slot won before; after the swap both slots are won.
    one_to_two = a0 < b1 <= a1 < b0
    return none_to_one or one_to_two
def naive2(n, m, a):
    """O(n * m^2) reference DP: count b in {1..m}^n for which some adjacent
    swap of b strictly increases the number of positions with a < b.

    Values are handled 0-based internally (a is shifted down by one and b
    ranges over 0..m-1).  The caller's list a is not modified.
    """
    a = [x - 1 for x in a]
    # dp0[v]: prefixes ending in value v with no improving swap found yet.
    # dp1[v]: prefixes ending in value v where an improving swap exists.
    dp0 = [1] * m
    dp1 = [0] * m
    for i in range(n - 1):
        ndp0 = [0] * m
        ndp1 = [0] * m
        for j in range(m):
            for k in range(m):
                if g(a[i], a[i + 1], j, k):
                    ndp1[k] += dp0[j] + dp1[j]
                else:
                    ndp1[k] += dp1[j]
                    ndp0[k] += dp0[j]
        dp0, dp1 = ndp0, ndp1
    # dp1 always holds the final layer; the previous `sum(ndp1)` raised
    # UnboundLocalError when n == 1 because the loop never ran.
    return sum(dp1)
def h(a0, a1, a2, k, l, top=None):
    """Transition helper for solve: classify the next element b_{i+1}.

    A value b "beats" a threshold a when b > a.  Values range over 1..top.

    Args:
        a0, a1, a2: a[i], a[i+1], a[i+2].
        k: 2-bit state of b_i -- bit 0 set iff b_i > a0, bit 1 set iff
            b_i > a1.
        l: 3-bit mask for b_{i+1} -- bit 0 iff b_{i+1} > a0, bit 1 iff
            b_{i+1} > a1, bit 2 iff b_{i+1} > a2.
        top: largest admissible value.  Defaults to the module-level m,
            which is what this helper always used before the parameter
            existed, so existing 5-argument callers are unaffected.

    Returns:
        (gain, count): gain is True iff swapping b_i and b_{i+1} wins one
        more position; count is how many values in 1..top match mask l.
    """
    if top is None:
        top = m
    # Each bit pins b_{i+1} above (bit set) or at/below (bit clear) one
    # threshold; the admissible values form the interval (lo, hi].
    lo, hi = 0, top
    if l & 1:
        lo = max(lo, a0)
    else:
        hi = min(hi, a0)
    if (l >> 1) & 1:
        lo = max(lo, a1)
    else:
        hi = min(hi, a1)
    if l >> 2:
        lo = max(lo, a2)
    else:
        hi = min(hi, a2)

    pair = l & 3  # relation of b_{i+1} to a0 and a1 only
    if a0 <= a1:
        # Swap helps iff b_{i+1} lies in (a0, a1] and b_i is below both or
        # above both thresholds.
        gain = pair == 1 and k in (0, 3)
    else:
        # Swap helps iff b_i lies in (a1, a0] and b_{i+1} is below both or
        # above both thresholds.
        gain = k == 2 and pair in (0, 3)
    return gain, max(hi - lo, 0)
def solve(n, m, a):
    """Count b in {1..m}^n for which some adjacent swap wins one more position.

    A position i is won when a[i] < b[i].  The count is accumulated mod
    `mod` (module-level) during the DP; the final sum of the two-ish
    residues is returned unreduced and the caller reduces it.

    Args:
        n: length of a (n >= 1).
        m: upper bound on the values of b.
        a: the sequence a[0..n-1].  It is NOT modified; the previous version
            appended a sentinel to the caller's list in place.

    NOTE(review): the transitions call h(), which bounds values by the
    module-level m rather than this m argument; they coincide in this script.
    """
    # Sentinel a[n] = m: for the last transition bit 2 of the mask compares
    # against m, and no value in 1..m exceeds m, so it never adds a constraint.
    a = list(a) + [m]
    # dp[j][k]: j = 1 iff an improving swap already exists in b[0..i];
    # k = 2-bit state of b[i] (bit 0: b[i] > a[i], bit 1: b[i] > a[i+1]).
    dp = [[0] * 4 for _ in range(2)]
    dp[0][0] = min(a[0], a[1])
    dp[0][1] = max(a[1] - a[0], 0)
    dp[0][2] = max(a[0] - a[1], 0)
    dp[0][3] = m - max(a[0], a[1])
    for i in range(n - 1):
        ndp = [[0] * 4 for _ in range(2)]
        for j in range(2):
            for k in range(4):
                ways = dp[j][k]
                if ways == 0:
                    continue
                for l in range(8):
                    x, y = h(a[i], a[i + 1], a[i + 2], k, l)
                    if y:
                        # l // 2 = state of b[i+1] w.r.t. a[i+1] and a[i+2].
                        row = ndp[j | x]
                        row[l // 2] = (row[l // 2] + ways * y) % mod
        dp = ndp
    return sum(dp[1])
mod = 998244353
n, m = na()
a = na()
# Without swaps each position is independent: b[i] > a[i] in (m - a[i]) ways,
# and the other n-1 positions are free (m choices each).
base = sum(m - a[i] for i in range(n))
ans = base * pow(m, n - 1, mod) % mod
# solve() counts the sequences b for which one adjacent swap wins one more
# position; each such b contributes exactly one extra point.
print((ans + solve(n, m, a)) % mod)