"""Read ``n k`` from stdin and print f(n) mod 998244353.

f is defined by the recurrence
    f(x) = 1 + sum(f(i) for i in range(0, x - k + 1))
i.e. f(x) = 1 for x < k, and f(x) = 1 + f(0) + ... + f(x - k) otherwise.
"""
import sys

MOD = 998244353


def solve(n, k, mod=MOD):
    """Return f(n) % mod for the recurrence f(x) = 1 + f(0) + ... + f(x - k).

    Runs in O(n) time by keeping a running prefix sum of f instead of
    re-summing f(0..x-k) for every x (which is O(n^2)).

    Args:
        n: Argument at which to evaluate f. Any n < k (including negative n)
            yields the empty sum, so f(n) = 1.
        k: Offset of the recurrence; must be >= 1.
        mod: Modulus applied to the result (defaults to 998244353).

    Returns:
        f(n) reduced modulo ``mod``.

    Raises:
        ValueError: if k < 1. For such k the recurrence refers to f(x)
            itself (e.g. f(0) depends on f(0)), so it has no finite value.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    if n < k:
        # range(0, n - k + 1) is empty, so only the constant 1 remains.
        return 1 % mod

    # prefix[j] == (f(0) + f(1) + ... + f(j)) % mod
    prefix = []
    running = 0
    value = 1 % mod
    for x in range(n + 1):
        # f(x) = 1 + prefix[x - k] once x - k >= 0; otherwise just 1.
        value = (1 + prefix[x - k]) % mod if x >= k else 1 % mod
        running = (running + value) % mod
        prefix.append(running)
    return value


def main():
    """Read ``n k`` from one line of stdin and print f(n) mod 998244353."""
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()