"""Compute sum_{x=1}^{M} sum_{j=1}^{N} x**j * M**(N - j) modulo 998244353.

Reads two integers ``N M`` from stdin and prints the sum.
"""

MOD = 998244353


def solve(N, M, mod=MOD):
    """Return sum over x in [1, M], j in [1, N] of x**j * M**(N - j), mod ``mod``.

    For a fixed x the inner sum is a geometric series with ratio x / M whose
    closed form is ``x * (M**N - x**N) / (M - x)``. When ``M - x`` is 0 modulo
    ``mod`` every one of the N terms is congruent to M**N, so the sum is
    ``N * M**N``.  Keying the special case on the modular condition (rather
    than on ``x == M``) keeps the formula valid for every x, and the closed
    form never needs the inverse of x itself.

    Args:
        N: number of terms in each inner series (N == 0 gives 0).
        M: upper bound of x (M <= 0 gives 0).
        mod: modulus; must be prime so that a nonzero ``M - x`` is invertible.

    Returns:
        The sum reduced into ``range(mod)``.
    """
    m_pow = pow(M, N, mod)  # M**N does not depend on x: compute once
    total = 0
    for x in range(1, M + 1):
        diff = (M - x) % mod
        if diff == 0:
            # Ratio x / M is 1 (mod p): all N terms equal M**N.
            term = N * m_pow
        else:
            term = x * (m_pow - pow(x, N, mod)) * pow(diff, -1, mod)
        total = (total + term) % mod
    return total


if __name__ == "__main__":
    N, M = map(int, input().split())
    print(solve(N, M))