#yuki1184 mod=998244353 n,l=map(int,input().split()) n=-(-n//l) res=(pow(2,n,mod)-1)%mod print(res)