MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return the k-spaced counting DP value for ``n``, reduced modulo ``mod``.

    Let ``ending[i]`` be the number of strictly increasing sequences that
    start at position 0, stay within ``0..i``, end at ``i``, and have every
    consecutive gap ``>= k``. Then:

        ending[0] = 1
        ending[i] = sum(ending[0..i-k])   if i >= k, else 0

    ``prefix[i] = sum(ending[0..i]) % mod``, so ``ending[i] == prefix[i - k]``.
    The result is ``prefix[n]``. Equivalently, it counts the subsets of
    ``{1..n}`` whose smallest element is ``>= k`` and whose neighbouring
    elements differ by ``>= k``, including the empty subset.

    Args:
        n: Length of the range; ``n >= 0``.
        k: Minimum spacing. The recurrence only makes sense for ``k >= 1``.
            NOTE(review): ``k <= 0`` is presumably excluded by the input
            constraints. It is not validated here, to keep the original
            behaviour.
        mod: Modulus applied to every prefix sum.

    Returns:
        ``prefix[n]``, an int in ``[0, mod)``.
    """
    prefix = [1] + [0] * n
    for i in range(1, n + 1):
        # Only the value for position i is ever needed, so no array is kept
        # for `ending`; it is read straight out of the prefix sums.
        ending_here = prefix[i - k] if i >= k else 0
        prefix[i] = (prefix[i - 1] + ending_here) % mod
    return prefix[n]


def main() -> None:
    """Read ``n k`` from stdin and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()