N,M=map(int,input().split()) mod=998244353 dp=[0]*(M+1) dp[0]=1 for m in range(1,M+1): dp[m]+=dp[m-1] if 1