# https://yukicoder.me/problems/no/1887
MOD = 998_244_353


def solve(n, m, mod=MOD):
    """Count sequences in ``{1..m}**n`` containing ``v`` consecutive ``v``'s.

    The result is taken modulo ``mod``; "some value ``v``" may be any value.

    The complement ("good" sequences, in which every maximal run of value
    ``v`` has length at most ``v - 1``) is counted with a DP and subtracted
    from ``m**n``.  Value 1 can never appear in a good sequence, so the DP
    only iterates over values 2..m.

    Implicit DP state: ``dp[i][v]`` is the number of good prefixes of length
    ``i`` whose last maximal run has value ``v``.  ``dp[0][0] = 1`` stands
    for the empty prefix; slot 0 is never a real value.  Only one ``dp`` row
    is kept at a time, plus two prefix-sum tables:

    * ``cum[i][v]``    = sum of ``dp[j][v]`` for ``j <= i``
    * ``total_cum[i]`` = sum of ``dp[j][w]`` over every ``w`` and ``j <= i``

    Args:
        n: Sequence length (assumed >= 0).
        m: Alphabet size (assumed >= 0).
        mod: Modulus applied to every intermediate value.

    Returns:
        ``(m**n - good) % mod``.
    """
    cum = [[1] + [0] * m]
    total_cum = [1]
    last_row = cum[0]  # dp[0]: only the empty-prefix sentinel is set
    for i in range(n):
        prev_cum = cum[i]
        row = [0] * (m + 1)
        for v in range(2, m + 1):
            # A new run of v's covering positions j+1..i+1 must have length
            # 1..v-1.  So it extends a good prefix of length j in [i-v+2, i]
            # that does not already end with v.  `lo` is the prefix-sum index
            # just before that window; a negative value means the window
            # starts at the empty prefix.
            lo = i - (v - 1)
            if lo >= 0:
                any_end = total_cum[i] - total_cum[lo]
                same_end = prev_cum[v] - cum[lo][v]
            else:
                any_end = total_cum[i]
                same_end = prev_cum[v]
            row[v] = (any_end - same_end) % mod
        cum.append([(c + d) % mod for c, d in zip(prev_cum, row)])
        total_cum.append((total_cum[i] + sum(row)) % mod)
        last_row = row
    good = sum(last_row) % mod
    return (pow(m, n, mod) - good) % mod


def main():
    """Read ``N M`` from one line of stdin and print ``solve(N, M)``."""
    N, M = map(int, input().split())
    print(solve(N, M))


if __name__ == "__main__":
    main()