import sys def MI(): return map(int, sys.stdin.readline().split()) md=998244353 n,l=MI() c=(n+l-1)//l print(pow(2,c,md)-1)