l=input() m=input() s=int(l/m) if l%m!=0: s=s+1 t=pow(2,s)-1 t=t%998244353 print(t)