MOD = 998244353


def _accumulate_row(above, row, target, mod):
    """Set ``target[v] = above[v] + row[2] + ... + row[v]`` (mod ``mod``) for ``v >= 2``.

    ``above``, ``row`` and ``target`` all have length ``m + 1``; entries at
    indices 0 and 1 of ``target`` are left untouched.
    """
    running = 0
    for v in range(2, len(row)):
        running = (running + row[v]) % mod
        target[v] = (above[v] + running) % mod


def solve(n, m, mod=MOD):
    """Count length-``n`` sequences over ``1..m`` that contain a "long" run.

    A sequence is counted when some maximal run of equal values ``v`` has
    length ``>= v``. The complement is counted by DP and subtracted from
    ``m ** n``. In the complement, every maximal run of value ``v`` is shorter
    than ``v``, which excludes the value 1 entirely.

    ``dp[i][v]`` is the number of complement prefixes of length ``i + 1`` whose
    last maximal run has value ``v`` and ends at index ``i``. That run begins
    right after some index ``j`` with ``i - v + 1 <= j <= i - 1``, so its length
    ``i - j`` is below ``v``. The run ending at ``j`` must have a different value
    ``w != v`` (this is what makes the runs maximal). The virtual index
    ``j == -1`` (empty prefix) contributes 1.

    Only the current dp row is stored. The 2-D prefix table ``prefix``
    provides O(1) window sums over earlier rows.

    Args:
        n: sequence length.
        m: largest allowed value.
        mod: modulus applied to the result.

    Returns:
        The number of such sequences modulo ``mod``.
    """
    if n == 0:
        # The empty sequence has no run at all, so nothing is counted.
        return 0

    # prefix[k][v] = sum of dp[j][w] over rows j < k and values 2 <= w <= v.
    prefix = [[0] * (m + 1) for _ in range(n + 1)]

    # Row 0: a single run of length 1 is valid for every value v >= 2.
    row = [1 if v >= 2 else 0 for v in range(m + 1)]
    _accumulate_row(prefix[0], row, prefix[1], mod)

    for i in range(1, n):
        cur = prefix[i]
        for v in range(2, m + 1):
            lo = prefix[max(0, i - v + 1)]
            # Sum of dp[j][w] over rows j in [lo, i - 1] and all w except v:
            # the whole window, minus column v (= [2..v] minus [2..v-1]).
            total = (cur[m] - lo[m]) - (cur[v] - lo[v]) + (cur[v - 1] - lo[v - 1])
            if v > i + 1:
                # The run may cover indices 0..i (length i + 1 < v).
                total += 1
            # Safe to overwrite in place: this loop reads only ``prefix``.
            row[v] = total % mod
        _accumulate_row(cur, row, prefix[i + 1], mod)

    return (pow(m, n, mod) - sum(row[2:])) % mod


def main():
    """Read ``N M`` from stdin and print the answer modulo 998244353."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()