"""Read integers ``n`` and ``l`` from stdin; print (2**ceil(n / l) - 1) mod 998244353."""

MOD = 998244353


def solve(n, l, mod=MOD):
    """Return ``(2 ** ceil(n / l) - 1) % mod``.

    Presumably this counts the non-empty subsets of ``ceil(n / l)`` groups;
    NOTE(review): confirm against the original problem statement.

    Args:
        n: Integer numerator.
        l: Integer divisor; must be non-zero.
        mod: Modulus for the result (defaults to 998244353).

    Returns:
        The value ``2 ** ceil(n / l) - 1`` reduced into ``[0, mod)``.

    Raises:
        ZeroDivisionError: If ``l`` is 0.
    """
    # -(-n // l) is exact integer ceiling division; it avoids float rounding
    # that math.ceil(n / l) would suffer for very large n.
    groups = -(-n // l)
    # Reduce once more so the result stays in [0, mod) even if pow() yields 0
    # (possible for moduli such as 1), which would otherwise leave -1.
    return (pow(2, groups, mod) - 1) % mod


def main():
    """Parse ``n l`` from one line of stdin and print the answer."""
    n, l = map(int, input().split())
    print(solve(n, l))


if __name__ == "__main__":
    main()