import sys input=lambda: sys.stdin.readline().rstrip() n,l=map(int,input().split()) mod=998244353 m=(n-1)//l+1 print((pow(2,m,mod)-1)%mod)