MOD = 998244353


def solve(n, k, mod=MOD):
    """Return (dp[0] + dp[1] + ... + dp[n]) % mod for the recurrence below.

    dp[0] = 1; for 0 < i < k, dp[i] = 0; for i >= k,
    dp[i] = dp[0] + dp[1] + ... + dp[i - k]  (all modulo ``mod``).

    A running prefix-sum array makes each step O(1), so the whole
    computation is O(n) time and memory.

    Args:
        n: Largest index of the dp table (table has n + 1 entries).
        k: Minimum gap between an index and the indices it sums over.
        mod: Modulus applied to every intermediate value.

    NOTE(review): intended for k >= 1. With k == 0 the lookup below reads a
    prefix entry that has not been filled yet, so every dp value (including
    dp[0]) becomes 0 and the result is 0 -- identical to the original script.
    Negative k can index past the prefix array; confirm input constraints.
    """
    dp = [1] + [0] * n
    # prefix[j] == (dp[0] + ... + dp[j - 1]) % mod; prefix[0] is always 0.
    prefix = [0] * (n + 2)
    for i in range(n + 1):
        last = i - k  # highest index that contributes to dp[i]
        if last >= 0:
            dp[i] = prefix[last + 1] % mod
        prefix[i + 1] = (prefix[i] + dp[i]) % mod
    # prefix[n + 1] is the sum of the whole dp table, already reduced.
    return prefix[n + 1]


def main():
    """Read "N K" from stdin and print solve(N, K)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()