"""Read two integers ``N M`` from stdin and print a modular weighted sum.

The printed value is, modulo 998244353,

    sum over s in [0, M) of
        D(s) * (M - s) * s * ((2 + s) + (2*M - s)) * N / 4

where ``D(s) = (s+1)**N - 2*s**N + (s-1)**N`` is the second finite
difference of ``x**N`` at ``s`` and the division by 4 is done with the
modular inverse of 2 applied twice.
"""

MOD = 998244353


def solve(n, m):
    """Return the sum described in the module docstring, reduced mod MOD.

    Args:
        n: exponent used in the power terms (the first input integer).
        m: exclusive upper bound of the summation index (the second input
            integer).

    Returns:
        An int in ``[0, MOD)``.
    """
    # Only s >= 1 can contribute because every term carries a factor of s.
    if m < 2:
        return 0

    inv2 = pow(2, MOD - 2, MOD)  # Fermat inverse of 2; MOD is prime.

    # (1 + 1 + s) + (m + m - s) == 2*m + 2: the s terms cancel, so this
    # factor, together with n and the two inverse-2 factors, is constant
    # across the loop and equals (m + 1) * n / 2 modulo MOD.
    weight = (m + 1) * n % MOD * inv2 % MOD

    # Rolling window of (s-1)**n, s**n, (s+1)**n so each step needs only
    # one new modular exponentiation instead of three.
    below = pow(0, n, MOD)
    here = pow(1, n, MOD)
    total = 0
    for s in range(1, m):
        above = pow(s + 1, n, MOD)
        # Second finite difference of x**n at s.
        # NOTE(review): presumably this counts length-n sequences whose
        # max - min equals s within a fixed window — confirm against the
        # problem statement.
        second_diff = above - 2 * here + below
        # Python's % yields a non-negative result even if second_diff < 0.
        total = (total + second_diff * (m - s) * s) % MOD
        below, here = here, above

    return total * weight % MOD


def main():
    """Parse ``N M`` from one line of stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()