N, K = map(int, input().split()) dp = [0]*(N+1) imos = [0]*(N+1) MOD = 998244353 dp[0] = 1 for i in range(N): if i > 0: imos[i] += imos[i-1] imos[i] %= MOD dp[i] = imos[i] if i+K <= N: imos[i+K] += dp[i] imos[i+K] %= MOD imos[-1] += imos[-2] imos[-1] %= MOD dp[-1] = imos[-1] print(sum(dp)%MOD)