MOD = 998244353


def count_sequences(n, m):
    """Return (m**n - B) mod 998244353, where B is the DP total described below.

    For every length L (1..n) and value v (2..m) define f(L, v) as

        f(L, v) = sum of f(j, u) over all u != v and all j with
                  max(0, L - v) < j < L,
                  plus 1 when v > L.

    B is the sum of f(n, v) over v in 2..m.

    Interpretation (inferred from the recurrence, the problem statement is
    not part of this file -- confirm): f(L, v) counts length-L sequences over
    2..m that end in a run of v and in which every maximal run of a value u
    is shorter than u, so the result counts sequences over 1..m that contain
    some value x repeated at least x times in a row.

    Args:
        n: sequence length, must be at least 1.
        m: largest allowed value.

    Returns:
        The count modulo 998244353.

    Raises:
        ValueError: if n is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    if m < 2:
        # The DP only ranges over values 2..m, so nothing gets subtracted.
        return pow(m, n, MOD)

    # cum[L][v] = sum of f(j, u) for all lengths j <= L and values u <= v,
    # kept modulo MOD.  Row 0 is all zeros, which makes the length-1 case
    # fall out of the general recurrence (only the "+1" term survives).
    cum = [[0] * (m + 1) for _ in range(n + 1)]

    row_total = 0
    for length in range(1, n + 1):
        prev = cum[length - 1]
        cur = cum[length]
        prev_all = prev[m]
        row_total = 0  # running sum of f(length, u) for u <= v
        for v in range(2, m + 1):
            # The final run of v may have 1..v-1 elements, so whatever
            # precedes it has a length in the window (length - v, length).
            oldest = cum[max(0, length - v)]
            ways = (
                (prev_all - oldest[m])             # every value in the window
                - (prev[v] - oldest[v])            # minus values <= v
                + (prev[v - 1] - oldest[v - 1])    # plus values <= v - 1
            )
            if v > length:
                # The whole sequence may be a single run of v.
                ways += 1
            row_total = (row_total + ways) % MOD
            cur[v] = (prev[v] + row_total) % MOD

    # After the last iteration row_total is the sum of f(n, v) over all v.
    return (pow(m, n, MOD) - row_total) % MOD


def main():
    """Read N and M from standard input and print the answer."""
    n, m = map(int, input().split())
    print(count_sequences(n, m))


if __name__ == "__main__":
    main()