MOD = 998244353


def solve(n, k, mod=MOD):
    """Return (k * k**n - sum_{i=1..k} [(k-i)**n + n*i*(k-i)**(n-1)]) modulo ``mod``.

    Substitute j = k - i. The summand j**n + n*(k-j)*j**(n-1) counts the
    sequences of length n over 1..k with at most one entry greater than j.
    Subtracting it from k**n, and summing over j = 0..k-1, therefore gives
    the sum, over all k**n such sequences, of their second-largest entry.
    Sequences with fewer than two entries contribute 0.

    Args:
        n: sequence length (non-negative).
        k: largest allowed value (non-negative).
        mod: modulus for all arithmetic; defaults to 998244353.

    Returns:
        The value above, reduced into range(mod).
    """
    total = 0
    for i in range(1, k + 1):
        j = k - i
        term = pow(j, n, mod)
        # The second summand carries a factor of n. When n == 0 it is 0, and
        # evaluating pow(0, -1, mod) for i == k would raise ValueError, so skip it.
        if n:
            term += n * i % mod * pow(j, n - 1, mod)
        total = (total + term) % mod
    return (k * pow(k, n, mod) - total) % mod


def main():
    """Read "N K" from stdin and print solve(N, K)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()