N,K=map(int,input().split()) Mod=998244353 T=[0]*(K+1); U=pow(K,N,Mod) for k in range(K,0,-1): """ 2nd_max(T)>=k iff not (全て k 未満 or ( k 以上がちょうど 1 個, 残りは k 未満)) """ T[k]=U-(pow(k-1,N,Mod)+N*(K-k+1)*pow(k-1,N-1,Mod))%Mod print(sum(T)%Mod)