MOD = 998244353


def _ways_all_marked_seen(n, m, i, mod):
    """Return (mod ``mod``) how many length-``n`` sequences contain all ``i`` marked symbols.

    DP state ``k`` = how many of the ``i`` marked symbols have appeared so far.
    Each step either picks one of the ``i - k`` unseen marked symbols
    (advancing to ``k + 1``) or one of ``m - i + k + extra`` other choices
    (staying at ``k``). Every step therefore has ``m + extra`` choices in total.
    The result counts sequences over an alphabet of ``m + extra`` symbols in
    which all ``i`` marked symbols occur.
    """
    # NOTE(review): the alphabet grows by one symbol only when i == m + 1;
    # presumably this mirrors a special case of the original problem
    # statement -- confirm against the task before changing it.
    extra = 1 if i == m + 1 else 0
    # States k > i are unreachable (the k -> k+1 factor is i - k, which is
    # 0 at k == i), so the table never needs more than i + 1 entries.
    dp = [1] + [0] * i
    for _ in range(n):
        ndp = [0] * (i + 1)
        for k, ways in enumerate(dp):
            ndp[k] += ways * (m - i + k + extra)
            if k < i:
                ndp[k + 1] += ways * (i - k)
        dp = [v % mod for v in ndp]
    return dp[i]


def solve(n, m, mod=MOD):
    """Return ``sum(i*i * W(i))`` modulo ``mod``, for i up to ``min(n, m + 1)``.

    ``W(i)`` is the number of length-``n`` sequences in which ``i`` marked
    symbols all appear; see ``_ways_all_marked_seen``.

    Args:
        n: sequence length (number of DP steps); expected to be >= 0.
        m: alphabet size parameter; expected to be >= 0.
        mod: modulus for the result.

    Returns:
        The weighted sum reduced into ``[0, mod)``.
    """
    total = 0
    # i == 0 contributes i*i == 0, so start at 1.
    for i in range(1, min(n, m + 1) + 1):
        total = (total + _ways_all_marked_seen(n, m, i, mod) * i * i) % mod
    return total


def main():
    """Read ``N M`` from stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()