N,L=map(int,input().split()) MOD=998244353 ans=pow(2,N//L+min(1,N%L),MOD)-1 print(ans+MOD if ans<0 else ans)