MOD = 998244353


def solve(n, l, mod=MOD):
    """Return ``(2 ** ceil(n / l) - 1) % mod``.

    Args:
        n: Numerator of the ceiling division.
        l: Divisor. Assumed positive (TODO confirm against problem
            constraints); ``l == 0`` raises ``ZeroDivisionError``.
        mod: Modulus for the result; defaults to 998244353.

    Returns:
        ``2 ** p - 1`` reduced modulo ``mod``, where ``p = ceil(n / l)``.
    """
    # Integer ceiling division; avoids float rounding for large n.
    p = (n + l - 1) // l
    # The trailing `% mod` maps a residue of 0 to mod - 1. The original code
    # special-cased that with an explicit branch, but for the default prime
    # modulus the branch is unreachable: mod never divides a power of 2.
    return (pow(2, p, mod) - 1) % mod


def main():
    """Read ``n l`` from one line of stdin and print ``solve(n, l)``."""
    n, l = map(int, input().split())
    print(solve(n, l))


if __name__ == "__main__":
    main()