MOD = 998244353


def solve(n, m):
    """Return the script's answer for inputs ``n`` and ``m``, modulo ``MOD``.

    Computes, for ``max_x = min(n + 1, m + 1)``::

        sum_{x=1}^{max_x} x * sum_{s=0}^{x} (-1)^(x-s) * C(x, s) * base(x, s)^n

    where ``base(x, s) = s + 2*m - x + 1`` when ``x <= m`` and
    ``base(x, s) = s + x + 1`` otherwise, with all arithmetic mod ``MOD``.

    Args:
        n: Exponent applied to every base (first input integer).
        m: Second input integer; selects which base formula is used and
            bounds the outer sum.

    Returns:
        The sum above, reduced into ``[0, MOD)``.
    """
    max_x = min(n + 1, m + 1)

    # Only O(max_x) distinct bases occur across all (x, s) pairs, so caching
    # pow() avoids O(max_x^2) redundant modular exponentiations.
    power_cache = {}

    def power(base):
        """Return ``base ** n % MOD``, memoised on ``base % MOD``."""
        base %= MOD
        value = power_cache.get(base)
        if value is None:
            value = pow(base, n, MOD)
            power_cache[base] = value
        return value

    ans = 0
    # row holds C(x, 0..x) mod MOD. It is rebuilt each step from the
    # previous row, so memory is O(max_x) and there is no fixed size limit.
    row = [1]
    for x in range(1, max_x + 1):
        row = [1] + [(row[s - 1] + row[s]) % MOD for s in range(1, x)] + [1]
        # base(x, s) == s + offset; the offset depends only on x.
        offset = 2 * m - x + 1 if x <= m else x + 1
        res = 0
        for s, comb in enumerate(row):
            term = comb * power(s + offset) % MOD
            if (x - s) % 2 == 1:  # (-1)^(x-s) without building a power
                term = -term
            res += term
        ans = (ans + x * res) % MOD  # Python % keeps the result non-negative
    return ans


def main():
    """Read ``N M`` from the first line of stdin and print the answer."""
    import sys

    N, M = map(int, sys.stdin.readline().split())
    print(solve(N, M))


if __name__ == "__main__":
    main()