MOD = 998244353


def count_ways(n: int, k: int) -> int:
    """Return the sum of the dp table for sizes ``n`` and ``k``, modulo MOD.

    The table is defined by::

        dp[0] = 1
        dp[i] = 1 + dp[1] + dp[2] + ... + dp[i - k]    for k <= i <= n
        dp[i] = 0                                      for 0 < i < k

    and the result is ``(dp[0] + dp[1] + ... + dp[n]) % MOD``.

    Args:
        n: Largest index of the table; the table has ``n + 1`` entries.
        k: Offset between an index and the last prefix term it may use.
            Expected to be non-negative.

    Returns:
        The table sum reduced modulo ``MOD``.
    """
    dp = [0] * (n + 1)
    dp[0] = 1

    # Running value of dp[1] + ... + dp[i - k], kept reduced modulo MOD.
    # dp[0] is deliberately left out of it: the "+ 1" below stands in for it.
    prefix = 0
    for i in range(k, n + 1):
        j = i - k
        if j >= 1:
            prefix = (prefix + dp[j]) % MOD
        # Reduce on every step: the values grow exponentially with n, so
        # deferring the modulus to the very end makes each addition operate
        # on arbitrarily large integers.
        dp[i] = (prefix + 1) % MOD

    return sum(dp) % MOD


def main() -> None:
    """Read ``N K`` from standard input and print the answer."""
    n, k = map(int, input().split())
    print(count_ways(n, k))


if __name__ == "__main__":
    main()