mod = 998244353  # modulus applied to the printed answer


def solve(n: int, L: int) -> int:
    """Return ``2**ceil(n / L) - 1`` reduced modulo ``mod``.

    ``ceil(n / L)`` is computed with pure integer arithmetic as
    ``(n + L - 1) // L``. This is exact for any integer ``n`` and positive
    integer ``L``, with no float rounding even for very large ``n``.

    Args:
        n: Dividend of the ceiling division.
        L: Divisor of the ceiling division.

    Returns:
        ``(2**k - 1) % mod`` where ``k = ceil(n / L)``.

    Raises:
        ZeroDivisionError: if ``L`` is 0.
    """
    k = (n + L - 1) // L
    # Three-argument pow does modular exponentiation without building 2**k.
    # The trailing ``% mod`` is a safeguard that keeps the result in [0, mod).
    return (pow(2, k, mod) - 1) % mod


def main() -> None:
    """Read ``n`` and ``L`` from one line of stdin and print the answer."""
    n, L = map(int, input().split())
    print(solve(n, L))


if __name__ == "__main__":
    main()