"""Evaluate an alternating binomial sum modulo 998244353.

Reads two integers ``n`` and ``k`` from one line of standard input and prints

    sum_{i=1..k} (-1)**(i+1) * C(k, i) * (k - i)**n * 2**(i-1)

reduced modulo 998244353.
"""

mod = 998244353

# framod[i] holds i! modulo ``mod``.  Only framod[0] is present until
# framod_calc() has been called.
framod = [1]


def framod_calc(n, mod, a=1):
    """Append factorial values modulo ``mod`` to the global ``framod`` list.

    Starting from the accumulator ``a``, appends ``a * 1!``, ``a * 2!``, ...,
    ``a * n!`` (each reduced modulo ``mod``).  With the default ``a=1`` and
    ``framod == [1]`` this leaves ``framod[i] == i! % mod`` for ``0 <= i <= n``.

    The function always appends, so calling it twice on the same list does
    not extend the table correctly; reset ``framod`` to ``[1]`` first.
    """
    for i in range(1, n + 1):
        a = a * i % mod
        framod.append(a)


def combmod(n, k, mod):
    """Return the binomial coefficient C(n, k) modulo ``mod``.

    Uses the precomputed ``framod`` table, which must already cover index
    ``n``.  The inverse is taken as ``x ** (mod - 2)``, so ``mod`` has to be
    prime (Fermat's little theorem).  Returns 0 when ``k`` lies outside
    ``0..n`` instead of reading the table with a negative index.
    """
    if k < 0 or k > n:
        return 0
    # One inversion of the combined denominator instead of two separate ones.
    denominator = framod[k] * framod[n - k] % mod
    return framod[n] * pow(denominator, mod - 2, mod) % mod


def solve(n, k):
    """Return the alternating sum described in the module docstring.

    Args:
        n: exponent applied to ``k - i`` in every term.
        k: upper limit of the summation and top argument of the binomial.

    Returns:
        The sum reduced into the range ``0 .. mod - 1``.
    """
    # Rebuild the factorial table from scratch: framod_calc() only appends,
    # so a table left over from an earlier call must not be reused.
    framod[:] = [1]
    framod_calc(k + 1, mod)

    total = 0
    sign = 1  # +1 for odd i, -1 for even i
    power_of_two = 1  # 2**(i-1) % mod, kept reduced to avoid huge integers
    for i in range(1, k + 1):
        term = combmod(k, i, mod) * pow(k - i, n, mod) % mod * power_of_two % mod
        # Python's % is non-negative for a positive modulus, so no "+ mod"
        # correction is needed after subtracting.
        total = (total + sign * term) % mod
        sign = -sign
        power_of_two = power_of_two * 2 % mod
    return total


def main():
    """Read ``n k`` from standard input and print the result."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()