"""Count compositions into large parts, modulo 998244353.

Reads ``n k`` from stdin and prints the number of ways to write some integer
in 0..n as an ordered sum of parts that are each at least ``k``. The empty
sum counts once, for 0.
"""

MOD = 998244353


def count_compositions(n: int, k: int, mod: int = MOD) -> int:
    """Return sum_{i=0..n} a(i) modulo ``mod``.

    Here a(i) is the number of compositions of i into parts each >= k, with
    a(0) = 1 for the empty composition.

    The recurrence is a(i) = b(i - k) for i >= k, and 0 otherwise, where
    b(j) = a(0) + ... + a(j) is the running prefix sum. This holds because
    the last part has some size p >= k, leaving a remainder i - p in
    0..i-k. Runs in O(n) time and O(n) memory.

    Assumes k >= 1. With k <= 0 the recurrence has no combinatorial meaning,
    and results for such input are not meaningful. The value of ``b[0]`` is
    returned unreduced (1), matching the original script.
    """
    # b[i] holds the prefix sum a(0) + ... + a(i), reduced mod `mod` for i >= 1.
    b = [1] + [0] * n
    for i in range(1, n + 1):
        # a(i): compositions of i whose parts are all >= k (0 when i < k).
        a_i = b[i - k] if i >= k else 0
        b[i] = (b[i - 1] + a_i) % mod
    return b[n]


def main() -> None:
    """Read ``n k`` from stdin and print the count modulo 998244353."""
    n, k = map(int, input().split())
    print(count_compositions(n, k))


if __name__ == "__main__":
    main()