N,L=map(int,input().split()) n=(N+L-1)//L m=998244353 print((pow(2,n,m)-1)%m)