"""Read integers ``n`` and ``l``; print ``2**ceil(n / l) - 1`` modulo 998244353.

When ``n <= l`` the answer is 1.
"""

MOD = 998244353


def solve(n: int, l: int) -> int:
    """Return ``2**(k) - 1 (mod MOD)`` where ``k = (n - 1) // l + 1``.

    For ``n >= 1`` and ``l >= 1`` the exponent ``k`` equals ``ceil(n / l)``.
    NOTE(review): assumes ``l >= 1`` (presumably guaranteed by the problem
    constraints -- confirm). ``l == 0`` with ``n > 0`` raises
    ``ZeroDivisionError``.

    Args:
        n: First input integer.
        l: Second input integer (divisor in the exponent).

    Returns:
        1 if ``n <= l``; otherwise ``pow(2, (n - 1) // l + 1, MOD) - 1``.
    """
    if n <= l:
        return 1
    # The original split on ``n % l == 0``, but for l >= 1 both arms equal
    # (n - 1) // l: when l does not divide n, n // l == (n - 1) // l.
    exponent = (n - 1) // l + 1
    # MOD is an odd prime, so pow(2, e, MOD) lies in [1, MOD - 1]; subtracting
    # 1 therefore stays in [0, MOD - 2] and needs no further reduction.
    return pow(2, exponent, MOD) - 1


def main() -> None:
    """Parse ``n l`` from one line of stdin and print the answer."""
    n, l = map(int, input().split())
    print(solve(n, l))


if __name__ == "__main__":
    main()