MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return dp[n] modulo ``mod`` for the recurrence dp[i] = dp[i-1] + dp[i-k].

    The base case is dp[0] = 1. The k-term only applies when i - k is a valid
    index (0 < k <= i). Presumably this counts ordered sequences of steps of
    length 1 and k summing to n; confirm against the problem statement.

    Args:
        n: Target index of the DP table (expected n >= 0).
        k: Length of the second step in the recurrence.
        mod: Modulus applied at every step (defaults to 998244353).

    Returns:
        dp[n] modulo ``mod``.
    """
    dp = [0] * (n + 1)
    dp[0] = 1
    for i in range(1, n + 1):
        ways = dp[i - 1]
        # Without this guard, i - k < 0 would index from the END of the list.
        # That reads a wrong value (or raises IndexError) when k > n + 1.
        if 0 < k <= i:
            ways += dp[i - k]
        dp[i] = ways % mod
    return dp[n]


def main() -> None:
    """Read "N K" from stdin and print solve(N, K)."""
    N, K = map(int, input().split())
    print(solve(N, K))


if __name__ == "__main__":
    main()