"""Count sequences containing a "full run", modulo 998244353.

Reads two integers ``n`` and ``m`` from the first line of standard input and
prints the number of length-``n`` sequences over the values ``1..m`` in which
some value ``x`` occurs at least ``x`` times in a row.

NOTE(review): the original problem statement is not part of this file; the
description above is derived from the recurrence and was cross-checked by
enumeration for small ``n`` and ``m``.
"""

MOD = 998244353


def count_sequences_with_full_run(n: int, m: int) -> int:
    """Return the count of sequences that contain a full run, modulo MOD.

    The answer is computed by complement: ``m**n`` minus the number of
    "run-free" sequences, i.e. sequences in which every maximal run of a
    value ``x`` is strictly shorter than ``x``.  The value 1 can never appear
    in a run-free sequence (a single 1 is already a run of length 1), so only
    the values ``2..m`` take part in the recurrence.

    Args:
        n: Sequence length; expected to be at least 1.
        m: Largest allowed value; expected to be at least 1.

    Returns:
        The number of sequences with at least one full run, in ``[0, MOD)``.
    """
    # run_starts[i][x]: run-free sequences of length i whose final run (of
    # value x) begins exactly at position i.  Row 0 stays all-zero and acts
    # as a sentinel for "nothing to expire yet".
    run_starts = [[0] * (m + 1) for _ in range(n + 1)]

    # ending_in[x]: run-free sequences of the current length that end in x.
    # It is a sliding-window sum of run_starts[k][x] over the last x - 1
    # start positions k, because a run of x may be at most x - 1 long.
    ending_in = [0] * (m + 1)

    for value in range(2, m + 1):
        run_starts[1][value] = 1
        ending_in[value] = 1

    for length in range(2, n + 1):
        # Total run-free sequences of length - 1; fixed before this row's
        # updates so every value sees the same previous-length total.
        prev_total = sum(ending_in) % MOD
        row = run_starts[length]
        for value in range(2, m + 1):
            # A new run of `value` can follow any run-free sequence that does
            # not already end in `value`.
            fresh = (prev_total - ending_in[value]) % MOD
            row[value] = fresh
            # A run started at length - value + 1 would now have reached
            # `value` elements, so it leaves the window.
            expired = run_starts[max(0, length - value + 1)][value]
            ending_in[value] = (ending_in[value] + fresh - expired) % MOD

    return (pow(m, n, MOD) - sum(ending_in)) % MOD


def main() -> None:
    """Read ``n m`` from standard input and print the answer."""
    n, m = map(int, input().split())
    print(count_sequences_with_full_run(n, m))


if __name__ == "__main__":
    main()