N,M=map(int,input().split()) Mod=998244353 two_inv=pow(2,Mod-2,Mod) X=N*pow(M,N+1,Mod)*(M+1)*two_inv X%=Mod for l in range(1,M+1): X-=(N*two_inv*pow(l-1,N,Mod)*l)%Mod Y=0 for l in range(1,M+1): Y+=pow(M-l+1,N,Mod)*(M+l) Y%=Mod Y*=N*two_inv Y%=Mod print((X-Y)%Mod)