n, k = map(int, input().split()) mod = 998244353 # dp[x] = レーティング x までの推移の数 dp = [0] * (n + 1) # prefix[x] = dp[0] + dp[1] + ... + dp[x] (累積和) prefix = [0] * (n + 1) dp[0] = 1 prefix[0] = 1 for x in range(1, n + 1): if x < k: dp[x] = 1 else: dp[x] = (1 + prefix[x - k]) % mod prefix[x] = (prefix[x - 1] + dp[x]) % mod print(dp[n])