N,M=map(int,input().split())
Mod=998244353

two_inv=pow(2,Mod-2,Mod)

X=N*pow(M,N+1,Mod)*(M+1)*two_inv
X%=Mod

for l in range(1,M+1):
    X-=(N*two_inv*pow(l-1,N,Mod)*l)%Mod

Y=0
for l in range(1,M+1):
    Y+=pow(M-l+1,N,Mod)*(M+l)
    Y%=Mod

Y*=N*two_inv
Y%=Mod

print((X-Y)%Mod)