N,M=map(int,input().split()) Mod=998244353 X=pow(M,N+1,Mod)*(M+1) Y=0 for l in range(1,M+1): a=pow(l-1,N,Mod)*l%Mod b=pow(M-l+1,N,Mod)*(M+l)%Mod Y+=a+b two_inv=pow(2,Mod-2,Mod) Z=N*(X-Y)*two_inv%Mod print(Z)