def count_ways(n: int, k: int, mod: int = 998244353) -> int:
    """Return dp[n] for the recurrence dp[i] = dp[i-1] + dp[i-k], modulo *mod*.

    Base case: dp[0] = 1. The dp[i-k] term is only added when i - k >= 0.
    Combinatorially, this counts ordered sequences of steps of size 1 and
    size k that sum to n. The two step kinds are counted as distinct even
    when k == 1, which gives 2**n.

    Args:
        n: Target index (n >= 0).
        k: Size of the second step. NOTE(review): assumes k >= 1. With
            k == 0 the recurrence doubles dp[i-1]; a negative k can index
            past the end of the table.
        mod: Modulus applied to every entry (default 998244353).

    Returns:
        dp[n] reduced modulo *mod*.
    """
    dp = [0] * (n + 1)
    dp[0] = 1
    for i in range(1, n + 1):
        dp[i] = dp[i - 1]
        if i - k >= 0:
            dp[i] += dp[i - k]
        # Reduce every step so entries stay bounded by mod.
        dp[i] %= mod
    return dp[n]


def main() -> None:
    """Read "N K" from one line of stdin and print count_ways(N, K)."""
    n, k = map(int, input().split())
    print(count_ways(n, k))


if __name__ == "__main__":
    main()