"""Fenwick-tree (binary indexed tree) DP.

Reads ``N K`` from stdin and prints (dp[0] + ... + dp[N]) mod 998244353, where
dp[0] = 1, dp[i] = 0 for 1 <= i < K, and dp[i] = dp[0] + ... + dp[i-K] for
i >= K.
"""

MOD = 998244353


class BIT:
    """Fenwick tree over 1-based positions 1..n with point add and prefix sum."""

    def __init__(self, n):
        self.size = n
        self.bit = [0] * (n + 1)  # slot 0 is unused; positions are 1-based

    def add(self, i, x):
        """Add ``x`` at position ``i``.

        Positions greater than ``size`` are silently ignored (the loop body
        never runs for them).

        Raises:
            IndexError: if ``i < 1``. With ``i == 0`` the step ``i & -i`` is 0,
                so the update loop would otherwise never terminate.
        """
        if i < 1:
            raise IndexError(f"BIT position must be >= 1, got {i}")
        while i <= self.size:
            self.bit[i] += x
            i += i & -i  # jump to the next node whose range covers i

    def sum(self, i):
        """Return the sum of positions 1..i (0 when ``i <= 0``)."""
        s = 0
        while i > 0:
            s += self.bit[i]
            i -= i & -i  # strip the lowest set bit to reach the next range
        return s

    def range_sum(self, l, r):
        """Return the sum of positions l..r inclusive."""
        return self.sum(r) - self.sum(l - 1)


def solve(n, k, mod=MOD):
    """Return (dp[0] + ... + dp[n]) % mod for the recurrence in the module docstring.

    dp[j] is stored at Fenwick position j + 1 because the tree is 1-based.
    """
    if k > n:
        # The loop below would not run; only dp[0] = 1 is nonzero.
        return 1 % mod
    ft = BIT(n + 1)
    ft.add(1, 1)  # dp[0] = 1
    for i in range(k, n + 1):
        # dp[i] = dp[0] + ... + dp[i-k], i.e. the prefix up to position i-k+1.
        ft.add(i + 1, ft.sum(i - k + 1) % mod)
    return ft.sum(n + 1) % mod


def main():
    """Read ``N K`` from stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == '__main__':
    main()