from sys import stdin

MOD = 998244353


def input() -> str:
    """Read one line from stdin without its trailing line terminator.

    ``readline()[:-1]`` would drop a real character when the last line of
    input has no trailing newline (e.g. "10 12" -> "10 1"), so only actual
    newline / carriage-return characters are stripped here.
    """
    return stdin.readline().rstrip("\r\n")


class FenwickTree:
    """0-indexed Fenwick (binary indexed) tree with values kept modulo MOD.

    Supports point updates and prefix sums over indices ``0..r`` inclusive.
    """

    def __init__(self, n: int) -> None:
        # Number of addressable positions: valid indices are 0..n-1.
        self.n = n
        self.bit = [0] * n

    def sum(self, r: int) -> int:
        """Return (value[0] + ... + value[r]) % MOD.

        A negative ``r`` denotes an empty prefix and yields 0. ``r >= n``
        raises IndexError.
        """
        total = 0
        while r >= 0:
            total = (total + self.bit[r]) % MOD
            # Jump to the last index not covered by the node at r.
            r = (r & (r + 1)) - 1
        return total

    def add(self, idx: int, delta: int) -> None:
        """Add ``delta`` (mod MOD) to value[idx]; indices >= n are ignored."""
        while idx < self.n:
            self.bit[idx] = (self.bit[idx] + delta) % MOD
            # Move to the next node whose range covers idx.
            idx |= idx + 1


def count_ways(n: int, k: int) -> int:
    """Return (dp[0] + ... + dp[n]) % MOD for the recurrence below.

    dp[0] = 1 and dp[i] = dp[0] + ... + dp[i - k] for 1 <= i <= n, where
    a negative upper index gives an empty sum. Equivalently, for k >= 1,
    dp[i] counts the ordered ways to write i as a sum of parts each >= k.
    The Fenwick tree keeps each prefix-sum query at O(log n).
    """
    tree = FenwickTree(n + 1)
    tree.add(0, 1)
    for i in range(1, n + 1):
        tree.add(i, tree.sum(i - k))
    return tree.sum(n)


def solve() -> None:
    """Read "n k" from one stdin line and print count_ways(n, k)."""
    n, k = map(int, input().split())
    print(count_ways(n, k))


if __name__ == "__main__":
    solve()