MOD = 998244353


def count_arrays(n, m):
    """Return the number of qualifying length-``n`` arrays over 1..m, mod MOD.

    NOTE(review): the problem statement is not part of this file. The
    recurrence below is consistent with "an array qualifies when, for some
    value k, it contains k consecutive copies of k" -- confirm against the
    original task.

    Each qualifying array is counted exactly once, at the shortest prefix
    that already qualifies; the remaining positions are then free, which is
    where the ``m ** (n - i)`` factor comes from.
    """
    # dp[i][k]: number of length-i prefixes that qualify for the first time
    # at position i, by completing a run of k copies of the value k.
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    # no[i][j]: number of length-i prefixes ending in j that do not qualify.
    no = [[0] * (m + 1) for _ in range(n + 1)]
    # The empty prefix is stored under the placeholder "last value" 0.
    no[0][0] = 1

    for i in range(n):
        no_sum = sum(no[i]) % MOD

        # Take a non-qualifying prefix of length i that does not end in k
        # and append k copies of k (only while the run still fits in n).
        for k in range(1, min(m, n - i) + 1):
            dp[i + k][k] = (dp[i + k][k] + no_sum - no[i][k]) % MOD

        # Append a single k to any non-qualifying prefix, then discard the
        # ones that just completed a run. dp[i + 1][k] is already final for
        # k >= 2 because its contributions came from earlier iterations.
        # no[i + 1][1] stays 0.
        for k in range(2, m + 1):
            no[i + 1][k] = (no_sum - dp[i + 1][k]) % MOD

    ans = 0
    for i in range(1, n + 1):
        # Positions after the first qualifying prefix are unconstrained.
        ans = (ans + sum(dp[i]) * pow(m, n - i, MOD)) % MOD
    return ans


def main():
    """Read ``n m`` from one line of stdin and print the answer."""
    n, m = map(int, input().split())
    print(count_arrays(n, m))


if __name__ == "__main__":
    main()