import sys
sys.setrecursionlimit(10 ** 8)
from functools import lru_cache
# Parse the two integers n and k from one whitespace-separated input line.
n, k = (int(token) for token in input().split())
# Every result printed by this script is reduced modulo this prime.
mod = 998_244_353
def f(x, step=None, modulus=None):
    """Return the x-th term of the sequence defined by a prefix-sum recursion.

    The term is defined as::

        f(x) = (1 + sum(f(i) for i in range(0, x - step + 1))) % modulus

    so f(x) == 1 % modulus whenever x < step, and for x >= step the terms
    satisfy f(x) == f(x - 1) + f(x - step) (mod ``modulus``).

    Memoized recursion re-sums up to x earlier terms for every x, which is
    O(x**2) additions overall. Here the terms are built bottom-up while a
    running prefix sum is maintained, so the cost is O(x) time and memory.
    No deep recursion happens either.

    Args:
        x: Index of the term to compute. Values below ``step`` (including
            negative ones) give ``1 % modulus``.
        step: Offset used in the recursion; defaults to the module-level ``k``.
        modulus: Value every result is reduced by; defaults to the
            module-level ``mod``.

    Returns:
        The x-th term reduced modulo ``modulus``.

    Raises:
        ValueError: If ``step`` is less than 1 and ``x >= step``. In that case
            the defining recursion refers back to terms it is still computing
            and never terminates.
    """
    # Resolve defaults lazily so the module-level values are read at call time.
    step = k if step is None else step
    modulus = mod if modulus is None else modulus
    if x < step:
        # The summation range is empty, so only the leading 1 remains.
        return 1 % modulus
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")
    terms = [0] * (x + 1)
    # At iteration i: prefix == (terms[0] + ... + terms[i - step]) % modulus.
    prefix = 0
    for i in range(x + 1):
        if i >= step:
            prefix = (prefix + terms[i - step]) % modulus
        terms[i] = (1 + prefix) % modulus
    return terms[x]
# Emit the answer for the parsed n as a decimal integer plus newline.
sys.stdout.write(f"{f(n)}\n")