import sys input = sys.stdin.readline N,M=list(map(int,input().split())) mod=998244353 ALL=pow(M,N,mod) ANS=ALL*N for i in range(2,M+1): #print(ANS) k=pow(M,mod-2,mod)*(M-i+1) #print(k) ANS=(ANS+ALL*k*(1-pow(k,N,mod))*pow(1-k,mod-2,mod))%mod print(ANS%mod)