def solve(): N, M = map(int, input().split()) MOD = 998244353 for x in range(N + 1): dp = [[0] * (x + 1) for _ in range(N + 1)] dp[0][0] = 1 for i in range(N): for j in range(x + 1): if dp[i][j] == 0: continue max_val = min(i + M, x) if j < x: ways = max(0, max_val - j) dp[i + 1][j + 1] = (dp[i + 1][j + 1] + dp[i][j] * ways) % MOD ways = max(0, max_val - j + 1) dp[i + 1][j] = (dp[i + 1][j] + dp[i][j] * ways) % MOD ans = dp[N][x] print(ans) if __name__ == "__main__": solve()