"""Modular DP over (length, value) computing m**n minus a DP-derived correction.

Reads two integers ``n m`` from standard input and prints one integer.

NOTE(review): the combinatorial meaning of the DP (which sequences it
counts) is not stated anywhere in this file; presumably it counts
length-n sequences over 1..m with some run-based restriction -- confirm
against the original problem statement.
"""

MOD = 998244353


def solve(n, m, mod=MOD):
    """Return ``(m**n - (dp[n][m] - dp[n - 1][m])) % mod`` for the DP below.

    ``dp`` is an ``(n + 1) x (m + 1)`` table with base cases
    ``dp[i][1] = 1`` for every ``i`` and ``dp[0][j] = 1`` for ``j >= 1``.
    For ``i >= 1`` and ``j >= 2`` each cell is a running (prefix) sum in
    ``j``; its increment is::

        dp[i][j] - dp[i][j - 1]
            = dp[i - 1][m]
              - [i >= j] * (dp[i - j][m] - (dp[i - j][j] - dp[i - j][j - 1]))

    All table entries are reduced modulo ``mod``.

    Args:
        n: Length parameter; assumed ``n >= 1`` (with ``n == 0`` the
            ``dp[n - 1]`` lookup wraps around to ``dp[-1]``).
        m: Value-range parameter; must be ``>= 1`` (column 1 is indexed).
        mod: Modulus for all arithmetic; defaults to 998244353.

    Returns:
        The answer as an int in ``[0, mod)``.

    Runs in O(n * m) time and memory.
    """
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for row in dp:
        row[1] = 1
    for j in range(1, m + 1):
        dp[0][j] = 1

    for i in range(1, n + 1):
        row = dp[i]
        # Row i - 1 is complete and never modified below, so its total
        # can be read once per row instead of once per cell.
        prev_total = dp[i - 1][m]
        for j in range(2, m + 1):
            value = row[j - 1] + prev_total
            # Only rows i - j >= 0 exist; without this guard Python's
            # negative indexing would silently read rows from the end.
            if i - j >= 0:
                back = dp[i - j]
                value -= back[j - 1]
                value += back[j]
                value -= back[m]
            row[j] = value % mod

    return (pow(m, n, mod) - (dp[n][m] - dp[n - 1][m])) % mod


def main():
    """Read ``n m`` from stdin and print ``solve(n, m)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()