MOD = 998244353


def solve(n: int, divisor: int, mod: int = MOD) -> int:
    """Return ``(2 ** ceil(n / divisor) - 1) % mod``.

    The ceiling is computed with integer arithmetic as
    ``(n + divisor - 1) // divisor``. This is exact for any integer ``n`` when
    ``divisor`` is positive, and it never goes through floating point, so it
    stays correct for very large inputs.

    Args:
        n: Numerator of the ceiling division.
        divisor: Denominator of the ceiling division. Expected to be positive.
        mod: Modulus for the result (defaults to 998244353).

    Returns:
        ``2 ** ceil(n / divisor) - 1`` reduced modulo ``mod``, in ``[0, mod)``.

    Raises:
        ZeroDivisionError: If ``divisor`` is 0.
    """
    k = (n + divisor - 1) // divisor  # ceil(n / divisor) without floats
    # Three-argument pow already yields a value in [0, mod). The trailing
    # % mod keeps the result non-negative when that value is 0.
    return (pow(2, k, mod) - 1) % mod


def main() -> None:
    """Read ``n`` and the divisor from one stdin line and print the answer."""
    n, divisor = map(int, input().split())
    print(solve(n, divisor))


if __name__ == "__main__":
    main()