def calc(n, l, md=998244353):
    """Return (2 ** ceil(n / l) - 1) modulo ``md``.

    The exponent is the ceiling of ``n / l``, computed with integer
    arithmetic as ``(n + l - 1) // l``.

    Args:
        n: Numerator of the exponent's ceiling division. Presumably a
            non-negative integer — TODO confirm against the problem
            constraints. A negative exponent would make three-argument
            ``pow`` compute a modular inverse rather than a fraction.
        l: Divisor of the exponent's ceiling division. Presumably a
            positive integer — TODO confirm. ``l == 0`` raises
            ``ZeroDivisionError``.
        md: Modulus. Defaults to 998244353, the value previously
            hard-coded.

    Returns:
        ``(2 ** ceil(n / l) - 1) % md`` as an int in ``[0, md)``.
    """
    ex = (n + (l - 1)) // l  # ceil(n / l) without floating point
    # Three-argument pow does modular exponentiation in O(log ex) steps
    # instead of materialising the full 2 ** ex big integer first.
    # The outer % keeps the result in [0, md) even if pow returns 0.
    return (pow(2, ex, md) - 1) % md


if __name__ == '__main__':
    # Input: a single line holding two integers, n and l.
    n, l = map(int, input().split())
    print(calc(n, l))