MOD = 998244353


def count_ways(n, k, mod=MOD):
    """Return (dp[0] + dp[1] + ... + dp[n]) % mod for the recurrence below.

    dp[0] = 1 and, for 1 <= i <= n,

        dp[i] = dp[0] + dp[1] + ... + dp[i - k]

    which is an empty sum (0) when i < k.

    A running prefix sum replaces the inner summation loop, so the table is
    built in O(n) time. Every stored value is reduced modulo ``mod`` to keep
    the integers small.

    Args:
        n: Largest index of the table; assumed to be >= 0.
        k: Minimum gap in the recurrence; must be >= 1.
        mod: Modulus applied to all stored values (defaults to MOD).

    Returns:
        The sum of dp[0..n] reduced modulo ``mod``.

    Raises:
        ValueError: If k < 1. The recurrence would then refer to dp[i] itself
            or to entries that are not yet computed.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")

    # prefix[t] == (dp[0] + ... + dp[t - 1]) % mod
    prefix = [0] * (n + 2)
    prefix[1] = 1  # dp[0] = 1
    for i in range(1, n + 1):
        # Sum of dp[0..i-k] is prefix[i - k + 1]; empty when i < k.
        dp_i = prefix[i - k + 1] if i >= k else 0
        prefix[i + 1] = (prefix[i] + dp_i) % mod
    return prefix[n + 1]


def main():
    """Read "N K" from stdin and print the total modulo MOD."""
    n, k = map(int, input().split())
    print(count_ways(n, k))


if __name__ == "__main__":
    main()