MOD = 998244353


def solve(n, m, mod=MOD):
    """Count "bad" length-``n`` sequences over ``{1, ..., m}``, modulo ``mod``.

    A sequence is bad when, for some value ``k``, it contains a run of at
    least ``k`` consecutive copies of ``k``. Value 1 is therefore bad as soon
    as it appears at all.

    Each bad sequence is counted once, by the first position ``i`` at which
    such a run is completed:

    * the prefix before that run is "good" (no completed run yet);
    * the run itself is fixed;
    * the ``n - i`` positions after ``i`` are free (``m ** (n - i)`` choices).

    Args:
        n: Sequence length, ``n >= 0``.
        m: Alphabet size, ``m >= 0``.
        mod: Modulus applied to the result.

    Returns:
        The number of bad sequences modulo ``mod``.

    Raises:
        ValueError: If ``n`` or ``m`` is negative.
    """
    if n < 0 or m < 0:
        raise ValueError("n and m must be non-negative")
    if m == 0:
        # With an empty alphabet only the empty sequence exists, and it is good.
        return 0

    # pow_m[k] = m**k mod `mod`: free choices for the positions after the run.
    pow_m = [1] * (n + 1)
    for k in range(1, n + 1):
        pow_m[k] = pow_m[k - 1] * m % mod

    # good[i][j]: good sequences of length i that end with value j
    # (2 <= j <= m). Columns 0 and 1 are never used, since a good
    # sequence cannot contain a 1.
    good = [[0] * (m + 1) for _ in range(n + 1)]
    # total[i]: all good sequences of length i; the empty sequence is good.
    total = [0] * (n + 1)
    total[0] = 1

    ans = 0
    for i in range(1, n + 1):
        tail = pow_m[n - i]
        # A 1 at position i completes a run of length 1 immediately.
        ans = (ans + total[i - 1] * tail) % mod
        row_total = 0
        for j in range(2, m + 1):
            ways = total[i - 1]
            if i >= j:
                # Good prefixes of length i-j that do not end in j, followed
                # by j copies of j: these complete a j-run exactly at i.
                completed = (total[i - j] - good[i - j][j]) % mod
                ways -= completed
                ans = (ans + completed * tail) % mod
            good[i][j] = ways % mod
            row_total += good[i][j]
        total[i] = row_total % mod
    return ans


def main():
    """Read ``n m`` from one line of stdin and print ``solve(n, m)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()