# Reads two integers "n k" from one line of standard input and prints dp[n]
# modulo 998244353, where dp[0] = 1 and dp[i] = dp[i-1] + dp[i-k] (the
# dp[i-k] term is added only when i >= k).

MOD = 998244353


def solve(n: int, k: int) -> int:
    """Return ``dp[n]`` modulo ``MOD`` for the recurrence below.

    The table is defined by ``dp[0] = 1`` and, for ``1 <= i <= n``,
    ``dp[i] = dp[i-1] + dp[i-k]``, where the ``dp[i-k]`` term is included
    only when ``i >= k``.

    Args:
        n: Index of the table entry to return.
        k: Offset of the second term of the recurrence.

    Returns:
        ``dp[n]`` reduced modulo ``MOD``.
    """
    # NOTE(review): assumes n >= 0 and k >= 1 (presumably guaranteed by the
    # input format); other values are not validated here -- confirm.
    dp = [1] + [0] * n
    for i in range(1, n + 1):
        total = dp[i - 1]
        if i >= k:
            total += dp[i - k]
        # Reduce at every step so table entries stay below MOD.
        dp[i] = total % MOD
    return dp[n]


def main() -> None:
    """Read ``n`` and ``k`` from standard input and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()