MOD = 998244353


def solve(n, k, mod=MOD):
    """Return (dp[0] + dp[1] + ... + dp[n]) modulo ``mod``.

    The table is defined by
        dp[0] = 1
        dp[i] = 0                              for 0 < i < k
        dp[i] = 1 + sum(dp[j] for j in k..i-k) for i >= k
    which is the number of ways to write ``i`` as an ordered sum of parts,
    each at least ``k`` (the "1" is the single part ``i`` itself; each
    dp[j] term is a shorter prefix followed by a final part of size i - j).

    Args:
        n: Largest index to include in the sum; must be >= 0.
        k: Minimum part size; must be >= 1.
        mod: Modulus applied to every stored value and to the result.

    Returns:
        The sum of dp[0..n], reduced modulo ``mod``.

    Raises:
        ValueError: If ``n`` is negative or ``k`` is less than 1.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if k < 1:
        # With k <= 0 the recurrence reads entries it has not filled yet
        # (and the underlying count is not finite), so reject it explicitly.
        raise ValueError(f"k must be at least 1, got {k}")

    dp = [0] * (n + 1)
    dp[0] = 1
    window = 0  # running sum of dp[k..i-k], kept reduced modulo `mod`
    for i in range(k, n + 1):
        if i >= 2 * k:
            # Index i - k has just become a valid predecessor (it is >= k).
            window = (window + dp[i - k]) % mod
        dp[i] = (window + 1) % mod
    return sum(dp) % mod


def main():
    """Read ``N K`` from one line of stdin and print ``solve(N, K)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()