"""Read ``n`` and ``L`` and print ``(2**ceil(n / L) - 1) mod 998244353``."""

MOD = 998244353


def solve(n: int, L: int, mod: int = MOD) -> int:
    """Return ``(2**k - 1) % mod`` where ``k = ceil(n / L)``.

    Args:
        n: Numerator of the ceiling division.
        L: Divisor of the ceiling division. Assumes ``L >= 1``
            (presumably guaranteed by the problem constraints — confirm);
            ``L == 0`` raises ``ZeroDivisionError``.
        mod: Modulus for the result; defaults to 998244353.

    Returns:
        ``2**k - 1`` reduced into ``range(mod)``.
    """
    # Integer ceiling division for positive L, avoiding float rounding.
    k = (n + L - 1) // L
    # pow(2, k, mod) may be 0 for some moduli, so reduce again after
    # subtracting 1; Python's % keeps the result non-negative.
    return (pow(2, k, mod) - 1) % mod


def main() -> None:
    """Read ``n L`` from one line of stdin and print the answer."""
    n, L = map(int, input().split())
    print(solve(n, L))


if __name__ == "__main__":
    main()