MOD = 998244353


def pow_r(x, n, mod=MOD):
    """Return ``x ** n % mod`` for a non-negative integer exponent ``n``.

    The previous hand-rolled recursion squared the base without reducing it
    (``x ** 2``), so intermediate values grew to 2**(2**k) and became
    infeasible for large exponents. The built-in three-argument ``pow``
    reduces modulo ``mod`` at every step and runs in O(log n)
    multiplications.

    Args:
        x: Integer base.
        n: Non-negative integer exponent.
        mod: Modulus; defaults to 998244353.

    Returns:
        ``x ** n`` reduced modulo ``mod``.

    Raises:
        ValueError: If ``n`` is negative. The old recursion never terminated
            in that case, and three-argument ``pow`` would silently compute a
            modular inverse instead.
    """
    if n < 0:
        raise ValueError("exponent must be non-negative")
    return pow(x, n, mod)


def solve(n, l):
    """Return 1 when ``n <= l``, otherwise ``(2 ** (n - l + 1) - 1) % MOD``."""
    if n <= l:
        # Short-circuit kept from the original; it also guarantees the
        # exponent passed to pow_r below is at least 2.
        return 1
    # Taking "- 1" modulo MOD maps a residue of 0 to MOD - 1, matching the
    # original's explicit "998244352" fallback.
    return (pow_r(2, n - l + 1) - 1) % MOD


def main():
    """Read ``n`` and ``l`` from one stdin line and print ``solve(n, l)``."""
    n, l = map(int, input().split())
    print(solve(n, l))


if __name__ == "__main__":
    main()