#include <stdio.h>

/*
 * Reads two integers N and M from stdin and prints a count modulo 998244353
 * computed by the recurrence in solve().  The combinatorial meaning of the
 * count is not documented in this file; comments below describe only the
 * mechanics of the computation.
 */

const int Mod = 998244353;

/* Capacity of every table below (valid indices 0..MAXN).  N and M must both
 * lie in [1, MAXN]; larger values would index past the ends of the tables. */
enum { MAXN = 3000 };

/*
 * Runs the DP for the given N and M and returns the result reduced mod Mod.
 * Preconditions (checked by main): 1 <= N <= MAXN and 1 <= M <= MAXN.
 */
static long long solve(int N, int M)
{
    /* tmp[i][j] for 1 <= i <= N, 2 <= j <= M, each kept in [0, Mod).
     * 8 bytes * 3001 * 3001 is roughly 72 MB, far larger than a typical
     * stack, so it must have static storage.  Every entry read below is
     * written earlier in the same call, so stale values from a previous
     * call are never observed. */
    static long long tmp[MAXN + 1][MAXN + 1];

    long long pw[MAXN + 1];            /* pw[k] = M^k mod Mod */
    long long r_sum[MAXN + 1] = {0};   /* per-column sliding-window sums of tmp */
    long long dp[MAXN + 1] = {0};      /* dp[k] = sum of tmp[i][j] with i + j - 1 == k */

    pw[0] = 1;
    for (int i = 1; i <= N; i++)
        pw[i] = pw[i - 1] * M % Mod;

    /* Row 1 of tmp is all ones; seed the window sums and dp[1 + j - 1] to match. */
    for (int j = 2; j <= M; j++) {
        tmp[1][j] = 1;
        r_sum[j] = 1;
        dp[j] = 1;
    }

    long long ans = pw[N - 1];
    long long sum = M - 1;
    for (int i = 2; i <= N; i++) {
        /* dp[i] is complete here: its contributions come only from rows < i,
         * because j >= 2 implies i' + j - 1 == i requires i' <= i - 1. */
        ans += (dp[i] + sum) % Mod * pw[N - i] % Mod;

        for (int j = 2; j <= M; j++) {
            /* Before the update below, r_sum[j] holds the sum of tmp[k][j]
             * over the previous j - 1 rows (or fewer near the top). */
            tmp[i][j] = sum - r_sum[j];
            if (tmp[i][j] < 0)
                tmp[i][j] += Mod;

            /* Schedule this entry to be added to ans / removed from sum at
             * step i + j - 1, if that step exists. */
            if (i + j - 1 <= N)
                dp[i + j - 1] += tmp[i][j];

            /* Slide the window: add row i, drop row i - j + 1. */
            r_sum[j] += tmp[i][j];
            if (i >= j)
                r_sum[j] -= tmp[i - j + 1][j];
            if (r_sum[j] < 0)
                r_sum[j] += Mod;
            else if (r_sum[j] >= Mod)
                r_sum[j] -= Mod;
        }

        /* Each term is < Mod and there are at most MAXN of them, so the
         * unreduced sum stays far below LLONG_MAX before this reduction. */
        for (int j = 2; j <= M; j++)
            sum += tmp[i][j];
        sum = (sum + Mod - dp[i] % Mod) % Mod;
    }
    return ans % Mod;
}

int main(void)
{
    int N, M;

    /* Guard against malformed input: N and M would otherwise be read
     * uninitialized. */
    if (scanf("%d %d", &N, &M) != 2)
        return 1;

    /* Reject sizes the fixed tables cannot hold, and N < 1 (pw[N - 1]
     * would read pw[-1]) or M < 1 (sum would start negative). */
    if (N < 1 || N > MAXN || M < 1 || M > MAXN) {
        fputs("N and M must be in [1, 3000]\n", stderr);
        return 1;
    }

    printf("%lld\n", solve(N, M));
    fflush(stdout);
    return 0;
}