import sys

MOD = 998244353


def solve(n, l, mod=MOD):
    """Return ``2 ** ceil(n / l) - 1`` reduced modulo *mod*.

    Args:
        n: Numerator of the ceiling division (read as ``N`` from input).
        l: Divisor of the ceiling division (read as ``L`` from input);
            expected to be positive.
        mod: Modulus for the result; defaults to 998244353.

    Returns:
        The value ``(2 ** ceil(n / l) - 1) % mod`` as a non-negative int.

    Raises:
        ZeroDivisionError: if ``l`` is 0.
    """
    # Integer ceiling division, identical to the original formulation.
    blocks = (n + l - 1) // l
    # Built-in three-argument pow performs fast modular exponentiation,
    # replacing the hand-written square-and-multiply loop.
    # The trailing % keeps the result non-negative even if pow returns 0.
    return (pow(2, blocks, mod) - 1) % mod


def main():
    """Read ``N L`` from stdin and print the answer."""
    n, l = map(int, sys.stdin.readline().split())
    print(solve(n, l))


if __name__ == "__main__":
    main()