DIV = 998244353


def count_ways(n: int, k: int, mod: int = DIV) -> int:
    """Return sum(dp[0..n]) modulo ``mod`` for the recurrence below.

    dp[0] = 1 and, for i >= 1, dp[i] = dp[0] + dp[1] + ... + dp[i - k]
    (dp[i] = 0 when i < k).  Equivalently, for k >= 1 this counts chains
    0 = a0 < a1 < ... < am <= n whose consecutive gaps are all >= k.

    Args:
        n: Upper index of the dp table (n >= 0).
        k: Minimum gap; dp[i] only sums entries at least k positions back.
        mod: Modulus applied to every partial sum.

    Returns:
        prefix[n] = (dp[0] + ... + dp[n]) reduced modulo ``mod`` at each step
        (prefix[0] itself is stored as the literal 1).
    """
    # prefix[i] holds dp[0] + ... + dp[i]; dp itself is never stored because
    # each dp[i] is consumed exactly once, to extend the prefix sum.
    prefix = [0] * (n + 1)
    prefix[0] = 1
    for i in range(1, n + 1):
        # dp[i] is a prefix sum ending k positions back; zero while i < k.
        dp_i = prefix[i - k] if i >= k else 0
        prefix[i] = (prefix[i - 1] + dp_i) % mod
    return prefix[n]


def main() -> None:
    """Read "N K" from stdin and print count_ways(N, K)."""
    n, k = map(int, input().split())
    print(count_ways(n, k))


if __name__ == "__main__":
    main()