MOD = 998244353


def solve(n, m, mod=MOD):
    """Return the sum over all length-``n`` sequences ``a`` with entries in
    ``[1, m]`` of ``(max(a) - min(a)) * sum(a)``, reduced modulo ``mod``.

    The identity used: for a value range ``[lo, hi]`` of size ``s``, the
    total of element-sums over all ``s**n`` sequences is
    ``n * s**(n-1) * (lo + hi) * s / 2`` (each position takes each value
    equally often).  Differencing that total between nested ranges isolates
    the sequences whose maximum (resp. minimum) is exactly ``v``.

    Args:
        n: sequence length (expected >= 1).
        m: largest allowed value; ``m <= 0`` yields 0.
        mod: modulus for the result.

    Returns:
        The sum described above, in ``[0, mod)``.
    """
    # size_pow[s] = s**(n-1) mod `mod` for range sizes s in 1..m; index 0 is
    # never read because empty ranges are handled before the lookup.
    size_pow = [0] + [pow(s, n - 1, mod) for s in range(1, m + 1)]
    n_mod = n % mod

    def sum_total(lo, hi):
        """Total of element-sums over all sequences with entries in [lo, hi]."""
        size = hi - lo + 1
        if size <= 0:
            return 0
        # (lo + hi) * size is always even, so the halving is exact before
        # any modular reduction.
        value_sum = (lo + hi) * size // 2 % mod
        return n_mod * size_pow[size] % mod * value_sum % mod

    max_weighted = 0
    min_weighted = 0
    for v in range(1, m + 1):
        # Sequences in [1, v] minus those in [1, v-1]: maximum is exactly v.
        with_max_v = sum_total(1, v) - sum_total(1, v - 1)
        # Sequences in [v, m] minus those in [v+1, m]: minimum is exactly v.
        with_min_v = sum_total(v, m) - sum_total(v + 1, m)
        max_weighted = (max_weighted + v * with_max_v) % mod
        min_weighted = (min_weighted + v * with_min_v) % mod
    # Python's % keeps the difference non-negative.
    return (max_weighted - min_weighted) % mod


def main():
    """Read ``N M`` from stdin and print ``solve(N, M)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == '__main__':
    main()