MOD = 998244353


def solve(n, m):
    """Return the script's answer for inputs ``n`` and ``m``, modulo ``MOD``.

    Computes

        sum_{k=1}^{min(n, m + 1)}  k * sum_{s=0}^{k} C(k, s) * (-1)^(k - s) * (k + b_k + s)^n

    where ``b_k = 1 + 2 * (m - k)`` when ``m >= k`` and ``b_k = 0`` otherwise
    (the only such k in range is ``k == m + 1``).  The inner sum is the k-th
    forward difference of ``t -> t**n`` evaluated at ``t = k + b_k``.

    Args:
        n: Exponent / first input value.
        m: Second input value.

    Returns:
        The total reduced modulo ``MOD`` (0 when ``min(n, m + 1) < 1``).
    """
    k_max = min(n, m + 1)
    if k_max < 1:
        return 0

    # Binomials via factorials: O(k_max) memory instead of a fixed
    # 2001 x 2001 Pascal table (which overflowed for k_max > 2000).
    # Inverses use Fermat's little theorem, valid because MOD is prime
    # and k_max < MOD for any feasible input.
    fact = [1] * (k_max + 1)
    for i in range(1, k_max + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (k_max + 1)
    inv_fact[k_max] = pow(fact[k_max], MOD - 2, MOD)
    for i in range(k_max, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    total = 0
    # k never exceeds m + 1, so every k in this range is processed.
    for k in range(1, k_max + 1):
        b = 1 + 2 * (m - k) if m >= k else 0
        offset = k + b  # non-negative: b >= 1 whenever m >= k
        fact_k = fact[k]
        diff = 0
        for s in range(k + 1):
            binom = fact_k * inv_fact[s] % MOD * inv_fact[k - s] % MOD
            term = binom * pow(offset + s, n, MOD) % MOD
            # (-1)^(k - s): subtract when k - s is odd.
            if (k - s) & 1:
                diff -= term
            else:
                diff += term
        # Python's % always yields a non-negative result for a positive modulus.
        total = (total + k * (diff % MOD)) % MOD
    return total


def main():
    """Read ``N M`` from one line of stdin and print ``solve(N, M)``."""
    import sys

    n, m = map(int, sys.stdin.readline().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()