MOD = 998244353


def solve(n, m, mod=MOD):
    """Return sum over i in [0, min(n, m + 1)] of i*i * f(i), modulo ``mod``.

    For a fixed ``i``, ``f(i)`` is the total weight of ending in state ``i``
    after ``n`` steps of a path that starts in state 0. From state ``k`` a
    step either stays in ``k`` with weight ``m - i + k`` or advances to
    ``k + 1`` with weight ``i - k``. For ``0 <= i <= m`` this recurrence
    counts length-``n`` sequences over ``m`` symbols in which every symbol of
    a fixed ``i``-element set occurs:
      - stay means picking one of the ``m - i`` other symbols, or one of the
        ``k`` already-seen ones;
      - advance means picking one of the ``i - k`` unseen ones.

    NOTE(review): the loop bound admits ``i = m + 1`` (when ``n > m``), where
    the stay weight from state 0 is -1. That term is generally non-zero and
    is kept exactly as in the original program; confirm it is intended.

    Args:
        n: number of steps (sequence length).
        m: the ``m`` used in the stay weight and the upper bound on ``i``.
        mod: modulus for the result (defaults to 998244353).

    Returns:
        The sum reduced to the range ``[0, mod)``.
    """
    ans = 0
    for i in range(min(n, m + 1) + 1):
        # States above i are never reached: the advance weight (i - k) is 0
        # at k == i, so only states 0..i need to be stored.
        dp = [1] + [0] * i
        for _ in range(n):
            ndp = [0] * (i + 1)
            for k in range(i + 1):
                stay = dp[k] * (m - i + k)
                # Mass advancing from k - 1 into k carries weight i - (k - 1).
                advance = dp[k - 1] * (i - k + 1) if k else 0
                ndp[k] = (stay + advance) % mod
            dp = ndp
        ans = (ans + dp[i] * i * i) % mod
    return ans


def main():
    """Read ``N M`` from one line of stdin and print ``solve(N, M)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()