"""Read ``N K`` from stdin and print a prefix-sum DP total modulo 998244353.

The recurrence (taken directly from the original script) is::

    dp[0] = 1
    dp[i] = 0                               for 1 <= i < K
    dp[i] = 1 + dp[1] + ... + dp[i - K]     for K <= i <= N
    answer = (dp[0] + dp[1] + ... + dp[N]) mod 998244353

NOTE(review): this appears to count the ways to place non-overlapping
segments of length K into N cells, with the empty placement included
(dp[i] = placements whose last segment ends at cell i). Confirm against the
problem statement.
"""

MOD = 998244353


def count_placements(n: int, k: int, mod: int = MOD) -> int:
    """Return ``sum(dp[0..n]) % mod`` for the recurrence in the module docstring.

    Args:
        n: Upper index of the DP table (number of cells).
        k: Minimum offset between ``i`` and the last index summed into ``dp[i]``.
        mod: Modulus applied at every step. Every value is reduced as it is
            produced, so the numbers stay small and each step is O(1).

    Returns:
        The DP total reduced modulo ``mod``.
    """
    dp = [0] * (n + 1)
    dp[0] = 1  # contributes only to the final total, never to ``prefix``
    # prefix[j] == (dp[1] + ... + dp[j]) % mod. prefix[0] stays 0 on purpose:
    # the "+ 1" below supplies the standalone term instead of dp[0].
    prefix = [0] * (n + 1)
    for i in range(k, n + 1):
        j = i - k
        if j >= 1:
            # dp[j] is final here because j < i (for k >= 1).
            prefix[j] = (dp[j] + prefix[j - 1]) % mod
        dp[i] = (prefix[j] + 1) % mod
    return sum(dp) % mod


def main() -> None:
    """Read ``N K`` from stdin and print the DP total modulo ``MOD``."""
    n, k = map(int, input().split())
    print(count_placements(n, k))


if __name__ == "__main__":
    main()