MOD = 998244353


def main():
    """Read ``N K`` from stdin and print the DP value for index 0, modulo MOD.

    The recurrence, evaluated for indices N down to 0:

        ways[s] = 1                                  if s + K > N
        ways[s] = 1 + ways[s + K] + ... + ways[N]    otherwise

    A running suffix sum turns each range sum into a single lookup, so the
    whole table is filled in one O(N) pass.

    NOTE(review): assumes K >= 1 (not validated here); with K <= 0 the
    recurrence would read suffix entries that have not been filled yet.
    """
    import sys

    n, k = map(int, sys.stdin.readline().split())

    # With K > N the base case applies at every index 0..N, so the answer is 1.
    if k > n:
        print(1 % MOD)
        return

    size = n + 2
    ways = [0] * size
    # suffix[s] == (ways[s] + ways[s + 1] + ... + ways[n]) % MOD.
    # suffix[n + 1] is never written and stays 0 as the empty-sum sentinel.
    suffix = [0] * size

    for s in reversed(range(n + 1)):
        nxt = s + k
        ways[s] = 1 if nxt > n else (suffix[nxt] + 1) % MOD
        suffix[s] = (suffix[s + 1] + ways[s]) % MOD

    print(ways[0] % MOD)


if __name__ == "__main__":
    main()