# Windowed-sum DP evaluated with a Fenwick tree; all arithmetic is mod 998244353.

MOD = 998244353


class BIT:
    """Fenwick (binary indexed) tree over 1-based positions 1..n.

    Every stored node and every returned prefix sum is reduced mod MOD.
    """

    def __init__(self, n):
        self.n = n
        # Slot 0 is unused: Fenwick index arithmetic is 1-based.
        self.data = [0] * (n + 1)

    def sum(self, i):
        """Return (value[1] + ... + value[i]) mod MOD; returns 0 when i <= 0."""
        s = 0
        while i > 0:
            s = (s + self.data[i]) % MOD
            i -= i & -i  # strip the lowest set bit to reach the parent range
        return s

    def add(self, i, x):
        """Add x to position i (1-based), keeping every touched node mod MOD.

        Raises:
            IndexError: if i < 1. Without this check, i <= 0 would loop forever,
                because ``i & -i`` eventually becomes 0 and i stops advancing.
        """
        if i < 1:
            raise IndexError("BIT positions are 1-based")
        while i <= self.n:
            self.data[i] = (self.data[i] + x) % MOD
            i += i & -i  # move to the next node whose range covers i


def cnt(N, K):
    """Return (dp[1] + ... + dp[N + 1]) mod MOD for the recurrence

        dp[1] = 1
        dp[m] = dp[max(1, m - K + 1)] + ... + dp[m - 1]    for m >= 2,

    so each term is the sum of (at most) the K - 1 terms before it.
    Equivalently, dp[m] counts compositions of m - 1 into parts of size
    1..K-1, and the result counts such compositions of every total 0..N.
    K is expected to be >= 1; for K == 1 every dp[m] with m >= 2 is 0
    and the result is 1.
    """
    bit = BIT(N + 1)
    bit.add(1, 1)
    for i in range(1, N + 1):
        # dp[i + 1] = dp[lo + 1] + ... + dp[i], where lo = max(0, i - K + 1).
        # Python's % already yields a non-negative result, so no "+ MOD".
        val = (bit.sum(i) - bit.sum(max(0, i - K + 1))) % MOD
        bit.add(i + 1, val)
    return bit.sum(N + 1)


def main():
    """Read "N K" from one line of stdin and print cnt(N, K)."""
    N, K = map(int, input().split())
    print(cnt(N, K))


if __name__ == "__main__":
    main()