"""Evaluate the linear recurrence dp[i] = dp[i-1] + dp[i-k] modulo 998244353.

Reads two integers ``n`` and ``k`` from one line of standard input and prints
``dp[n]``. With ``dp[0] = 1``, this counts ordered sequences of 1-steps and
k-steps that sum to ``n`` (the two step kinds are treated as distinct even
when ``k == 1``, giving ``2**n``).
"""

MOD = 998244353


def count_ways(n: int, k: int, mod: int = MOD) -> int:
    """Return ``dp[n]`` where ``dp[0] = 1`` and ``dp[i] = dp[i-1] + dp[i-k]``.

    The term ``dp[i-k]`` is only added when ``i - k >= 0``; every value is
    reduced modulo ``mod``.

    Args:
        n: Target index; assumed non-negative.
        k: Length of the long step; presumably ``k >= 1`` (a negative ``k``
            can index past the table and raise ``IndexError``).
        mod: Modulus applied after each step (defaults to 998244353).

    Returns:
        ``dp[n] % mod``.
    """
    dp = [0] * (n + 1)
    dp[0] = 1
    for i in range(1, n + 1):
        # Assign first and then add in place, so that k == 0 keeps the
        # original script's semantics (dp[i] += dp[i] doubles the value).
        dp[i] = dp[i - 1]
        if i - k >= 0:
            dp[i] += dp[i - k]
        dp[i] %= mod
    return dp[n]


def main() -> None:
    """Read ``n k`` from stdin and print ``count_ways(n, k)``."""
    n, k = map(int, input().split())
    print(count_ways(n, k))


if __name__ == "__main__":
    main()