"""Read ``N K`` from stdin and print a prefix-sum DP value modulo 998244353.

The recurrence (see ``count_chains``) is::

    dp[0] = 1
    dp[i] = cs[i - K]            if i >= K, else 0
    cs[i] = cs[i - 1] + dp[i]    (cs[0] = 1)

and the program prints ``cs[N] mod 998244353``.
"""

MOD = 998244353


def count_chains(n: int, k: int, mod: int = MOD) -> int:
    """Return ``cs[n] % mod`` for the minimum-gap prefix-sum recurrence.

    With ``dp[0] = 1`` and ``dp[i] = sum(dp[j] for j <= i - k)`` (an empty sum,
    i.e. 0, when ``i < k``), this returns ``sum(dp[0..n]) % mod``.  For
    ``k >= 1`` that is the number of chains ``0 = p0 < p1 < ... < pm <= n``
    (``m >= 0``) whose consecutive differences are all at least ``k``.

    Only the prefix sums are stored: ``dp[i]`` equals ``cs[i - k]``, so it is
    read straight from the ``cs`` table instead of keeping a second array.

    Args:
        n: Upper bound of the range; expected to be ``>= 0``.
        k: Minimum gap.  Assumed ``>= 0`` (a negative ``k`` makes the lookback
            index run past ``n`` and raises ``IndexError``, as the original
            script did).
        mod: Modulus applied to every prefix sum.

    Returns:
        ``cs[n]`` reduced modulo ``mod``.
    """
    cs = [0] * (n + 1)
    cs[0] = 1 % mod
    for i in range(1, n + 1):
        # dp[i] == cs[i - k] when i >= k.  For k == 0 this reads cs[i] before it
        # is written (still 0), matching the original script's result.
        step = cs[i - k] if i >= k else 0
        cs[i] = (cs[i - 1] + step) % mod
    return cs[-1]


def main() -> None:
    """Read ``N K`` from one stdin line and print ``count_chains(N, K)``."""
    n, k = map(int, input().split())
    print(count_chains(n, k))


if __name__ == "__main__":
    main()