MOD = 998244353


def solve(n, l, mod=MOD):
    """Return ``(2 ** (n - l + 1) - 1) % mod``.

    The exponent is clamped at 0, so when ``l > n + 1`` the result is 0.
    This matches the original doubling loop, which simply did not execute
    when ``n - l + 1 <= 0``.

    Args:
        n: First integer read from the input line.
        l: Second integer read from the input line.
        mod: Modulus for the result (defaults to 998244353).

    Returns:
        The value ``2 ** max(n - l + 1, 0) - 1`` reduced modulo ``mod``.
    """
    # Clamp: pow() with a negative exponent and a modulus computes a modular
    # inverse (Python 3.8+), which is not what the original loop did.
    exponent = max(n - l + 1, 0)
    # Three-argument pow is O(log exponent), unlike an O(exponent) doubling
    # loop, which would time out for large n.
    return (pow(2, exponent, mod) - 1) % mod


def main():
    """Read ``n l`` from stdin and print ``(2 ** (n - l + 1) - 1) mod 998244353``."""
    n, l = map(int, input().split())
    print(solve(n, l))


if __name__ == "__main__":
    main()