n,m=map(int,input().split()) mod=998244353 ans=pow(2,n,mod)-1 x=1 for i in range(m): x*=pow(i+1,mod-2,mod)*(pow(2,n,mod)-1-i) x%=mod for i in range(m-1): ans*=pow(i+1,mod-2,mod)*(pow(2,n-1,mod)-1-i) ans%=mod ans=(x-ans)%mod print(ans)