"""Closed-form answer over sequences A in [1, M]^N, modulo 998244353.

Reads ``N M`` from stdin and prints

    N * (M + 1) / 2 * (M**N * (M - 1) - 2 * sum(l**N for l = 1..M-1))  (mod p)

For N >= 1, ``M**N * (M - 1) - 2 * sum(l**N, l = 1..M-1)`` equals the sum of
``max(A) - min(A)`` over all A in [1, M]^N:

* sum of max = sum_{k=1..M} (M**N - (k-1)**N)
* sum of min = sum_{k=1..M} k**N

Reflecting every element (a -> M + 1 - a) keeps ``max - min`` unchanged but
maps ``sum(A)`` to ``N * (M + 1) - sum(A)``. Therefore the printed value
equals the sum of ``(max(A) - min(A)) * sum(A)`` over all such A.

NOTE(review): presumably that is the quantity the task asks for -- confirm
against the problem statement.
"""

MOD = 998244353


def solve(n, m, mod=MOD):
    """Return the closed-form value for ``n`` and ``m`` reduced modulo ``mod``.

    ``mod`` must be an odd prime so that 2 has a Fermat inverse.
    """
    # sum over A in [1, m]^n of (max - min) = m^n * (m - 1) - 2 * sum l^n
    power_sum = sum(pow(l, n, mod) for l in range(1, m))
    spread = pow(m, n, mod) * (m - 1) - 2 * power_sum
    half = pow(2, mod - 2, mod)  # modular inverse of 2 (Fermat)
    # Python's % yields a non-negative result for a positive modulus,
    # so a negative `spread` is normalised correctly here.
    return n * (m + 1) % mod * (spread % mod) % mod * half % mod


def main():
    """Read ``N M`` from stdin and print ``solve(N, M)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()