MOD = 998244353


def solve(n: int, divisor: int, mod: int = MOD) -> int:
    """Return ``(2 ** ceil(n / divisor) - 1) % mod``.

    The exponent is computed as ``(n + divisor - 1) // divisor``, which equals
    ``ceil(n / divisor)`` when ``divisor`` is positive.

    Args:
        n: Numerator of the ceiling division.
        divisor: Denominator of the ceiling division; must be non-zero
            (presumably positive per the input constraints -- confirm).
        mod: Modulus for the result; defaults to 998244353.

    Returns:
        The value ``2**x - 1`` reduced into the range ``[0, mod)``, where
        ``x = ceil(n / divisor)``.

    Raises:
        ZeroDivisionError: If ``divisor`` is 0.
    """
    exponent = (n + divisor - 1) // divisor
    # Reduce after subtracting 1: pow(...) may be 0 for a non-prime modulus,
    # and a bare "- 1" would then yield -1 instead of mod - 1.
    return (pow(2, exponent, mod) - 1) % mod


def main() -> None:
    """Read ``n l`` from one line of stdin and print the answer."""
    n, divisor = map(int, input().split())
    print(solve(n, divisor))


if __name__ == "__main__":
    main()