P = 998244353 N, L = map(int, input().split()) print((pow(2, (N - 1) // L + 1, P) - 1) % P)