"""Print ``n * C(2n + m, 2n + 1)`` modulo 998244353 for ``n m`` read from stdin."""

MOD = 998244353


def solve(n: int, m: int, mod: int = MOD) -> int:
    """Return ``n * C(2n + m, 2n + 1) % mod``.

    The value is the product ``n * prod_{i=0}^{2n} (2n + m - i) / (i + 1)``,
    which telescopes to ``n * (2n+m)! / ((m-1)! * (2n+1)!)``.

    Instead of taking a modular inverse of every ``i + 1`` (2n + 1
    exponentiations), the numerator and denominator are accumulated
    separately and a single inverse is taken at the end. This gives the same
    residue. If some denominator term is divisible by ``mod``, both forms
    produce 0.

    Args:
        n: First input value; also the leading multiplier.
        m: Second input value.
        mod: Modulus; must be prime for the Fermat inverse to be valid
            (the default 998244353 is prime).

    Returns:
        The result reduced into ``range(mod)``. For ``n < 0`` the product is
        empty and the result is ``n % mod``.
    """
    terms = 2 * n + 1  # number of factors in the telescoping product
    numerator = 1
    denominator = 1
    for i in range(terms):
        numerator = numerator * (2 * n + m - i) % mod
        denominator = denominator * (i + 1) % mod
    # Fermat's little theorem: x^(p-2) is the inverse of x modulo prime p.
    # This uses one exponentiation instead of one per term.
    inverse = pow(denominator, mod - 2, mod)
    return n * numerator % mod * inverse % mod


def main() -> None:
    """Read ``n m`` from one line of stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()