N, L = map(int, input().split()) M = (N + 1) // L print((pow(2, M, 998244353) - 1) % 998244353)