"""Read ``N K`` from stdin and print sum(dp[0..N]) modulo 998244353.

The recurrence is dp[0] = 1 and dp[i] = sum(dp[j] for j <= i - K) for
i > 0.
"""

P = 998244353


def solve(n: int, k: int, mod: int = P) -> int:
    """Return (dp[0] + dp[1] + ... + dp[n]) % mod.

    dp[0] = 1 and, for i > 0, dp[i] is the sum of dp[j] over all
    j <= i - k.  For k >= 1 this means dp[i] counts chains
    0 = a0 < a1 < ... < am = i whose consecutive gaps are all >= k.

    The sums are maintained with a difference array so the whole
    computation is O(n) rather than O(n^2).

    Args:
        n: Largest index to include in the total (n >= 0).
        k: Minimum gap.  NOTE(review): the recurrence is only meaningful
            for k >= 1; k <= 0 is not validated, matching the original.
        mod: Modulus applied to every intermediate sum.

    Returns:
        The total modulo ``mod``.
    """
    # After the propagation step of iteration i - 1, sub[i] holds
    # sum(dp[j] for j <= i - k).  Contributions are scheduled at i + k and
    # carried forward one slot per iteration.  Length n + 2 so the final
    # carry into sub[n + 1] stays in bounds.
    sub = [0] * (n + 2)
    ans = 0
    for i in range(n + 1):
        # dp[0] is seeded with 1; every other dp[i] comes only from sub[i].
        cur = (sub[i] + (1 if i == 0 else 0)) % mod
        ans = (ans + cur) % mod
        if i + k <= n:
            # dp[i] becomes visible to every index >= i + k.
            sub[i + k] = (sub[i + k] + cur) % mod
        # Carry the running total forward so sub stays a prefix sum.
        sub[i + 1] = (sub[i + 1] + sub[i]) % mod
    return ans


def main() -> None:
    """Read ``N K`` from one line of stdin and print ``solve(N, K)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()