結果

問題 No.3119 A Little Cheat
ユーザー tassei903
提出日時 2025-04-19 01:41:46
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 715 ms / 2,000 ms
コード長 3,042 bytes
コンパイル時間 775 ms
コンパイル使用メモリ 82,048 KB
実行使用メモリ 105,600 KB
最終ジャッジ日時 2025-04-19 01:42:16
合計ジャッジ時間 27,796 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 49
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = lambda :sys.stdin.readline()[:-1]
ni = lambda :int(input())
na = lambda :list(map(int,input().split()))
yes = lambda :print("yes");Yes = lambda :print("Yes");YES = lambda : print("YES")
no = lambda :print("no");No = lambda :print("No");NO = lambda : print("NO")

#######################################################################
def f(a, b):
    ans = 0
    for i in range(n):
        if a[i] < b[i]:
            ans += 1
    return ans
F = 0
def score(a, b):
    global F
    b = list(b)
    ans = f(a,b)
    A = f(a, b)
    for i in range(n-1):
        b[i], b[i+1] = b[i+1], b[i]
        ans = max(ans, f(a, b))
        b[i], b[i+1] = b[i+1], b[i]
    if A < ans:
        print(b)
        F += 1
    return ans

from itertools import product
def naive(n, m, a):
    ans = 0
    for p in product(range(1, m + 1), repeat=n):
        ans += score(a, p)
    return ans


def g(a0, a1, b0, b1):
    if a1 < a0:
        a0, a1 = a1, a0
        b0, b1 = b1, b0
    return b0 <= a0 < b1 <= a1 or a0 < b1 <= a1 < b0

def naive2(n, m, a):
    dp0 = [1] * m
    dp1 = [0] * m
    a = [i-1 for i in a]
    for i in range(n-1):
        ndp0 = [0] * m
        ndp1 = [0] * m
        for j in range(m):
            for k in range(m):
                x = g(a[i], a[i+1], j, k)
                if x:
                    ndp1[k] += dp0[j] + dp1[j]
                else:
                    ndp1[k] += dp1[j]
                    ndp0[k] += dp0[j]
        dp1 = ndp1
        dp0 = ndp0
    return sum(ndp1)

def h(a0, a1, a2, k, l):
    L = 0
    R = m
    if l % 2:
        L = max(L, a0)
    else:
        R = min(R, a0)
    
    if (l // 2) % 2:
        L = max(L, a1)
    else:
        R = min(R, a1)
    if l // 4:
        L = max(L, a2)
    else:
        R = min(R, a2)
    
    if a0 <= a1:
        X = (l % 4 == 1) and (k == 0 or k == 3)
    else:
        X = (k == 2) and (l % 4 == 0 or l % 4 == 3)
    # print(bin(l), L, R)
    return X, max(R - L, 0)

def solve(n, m, a):
    a.append(m)
    dp = [[0 for i in range(4)] for j in range(2)]
    dp[0][0] = min(a[0], a[1])
    dp[0][1] = max(a[1] - a[0], 0)
    dp[0][2] = max(a[0] - a[1], 0)
    dp[0][3] = m - max(a[0], a[1])

    # print(dp)
    for i in range(n-1):
        ndp =  [[0 for i in range(4)] for j in range(2)]
        for j in range(2):
            for k in range(4):
                if dp[j][k] == 0:
                    continue
                for l in range(8):
                    x, y = h(a[i], a[i+1], a[i+2], k, l)
                    if y:
                        # print(a[i], a[i+1], a[i+2], bin(k), bin(l), x, y)
                        ndp[j|x][l // 2] += dp[j][k] * y
                        ndp[j|x][l // 2] %= mod
        dp = ndp
        # print(dp)
    
    return sum(dp[1])
mod = 998244353
n, m = na()
a = [x for x in na()]
# print(naive(n, m, a))
# print(F)
# print(naive2(n, m, a))
ans = 0
for i in range(n):
    ans += m - a[i]
ans = ans * pow(m, n-1, mod) % mod
# print(ans)
# print(solve(n, m, a))
print((ans + solve(n, m, a)) % mod)
0