MOD = 998244353


def count_sequences(n, k, mod=MOD):
    """Return ``sum(dp[0..n]) % mod`` for a gap-``k`` prefix-sum DP.

    The recurrence is ``dp[0] = 1`` and, for every ``i >= k``,
    ``dp[i] = dp[0] + dp[1] + ... + dp[i - k]``. All other entries are 0.
    For ``k >= 1`` this counts the chains ``0 = p0 < p1 < ... < pm <= n``
    whose consecutive steps are all at least ``k``. When ``k > n``, only the
    empty chain exists and the result is 1.

    Args:
        n: Largest index of the DP table (the table has ``n + 1`` entries).
        k: Minimum step between consecutive chain elements.
        mod: Modulus applied to every stored value and to the result.

    Returns:
        The sum of all ``dp`` entries, reduced modulo ``mod``.
    """
    dp = [1] + [0] * n
    # prefix[i] holds (dp[0] + ... + dp[i - 1]) % mod; prefix[0] is always 0.
    prefix = [0] * (n + 2)
    for i in range(n + 1):
        if i >= k:
            # Sum of dp[0..i-k]. prefix entries are already reduced mod `mod`.
            dp[i] = prefix[i - k + 1]
        prefix[i + 1] = (prefix[i] + dp[i]) % mod
    return sum(dp) % mod


def main():
    """Read ``N K`` from stdin and print ``count_sequences(N, K)``."""
    n, k = map(int, input().split())
    print(count_sequences(n, k))


if __name__ == "__main__":
    main()