MOD = 998244353


def solve(n, k):
    """Compute the script's answer for inputs ``n`` and ``k``, reduced modulo MOD.

    Let ``m = n - k``. A table ``dp[0..m]`` is filled so that ``dp[0] = 1``,
    ``dp[i] = dp[i-1]`` for ``0 < i < k``, and ``dp[i] = dp[i-1] + dp[i-k]``
    for ``i >= k`` (all values reduced modulo MOD). The result is
    ``(1 + sum(dp)) % MOD``.

    Args:
        n: first integer from the input line.
        k: second integer from the input line.

    Returns:
        The answer as an int in ``[0, MOD)``; returns 1 when ``n - k < k``.
    """
    m = n - k  # the DP table has m + 1 entries
    if m < k:
        # Also protects against m < 0, where the table would be empty.
        # NOTE(review): for 0 <= m < k the general DP below would produce
        # m + 2, not 1, so this guard overrides the recurrence on that range.
        # Confirm against the problem statement whether the intended
        # threshold is `m < 0` (i.e. n < k) instead.
        return 1
    dp = [0] * (m + 1)
    dp[0] = 1
    for i in range(m + 1):
        if i != 0:
            # Carry the previous value forward (dp[i] may already hold dp[i-k]).
            dp[i] = (dp[i] + dp[i - 1]) % MOD
        if i + k <= m:
            # Push this entry k positions ahead; each target receives one push.
            dp[i + k] += dp[i]
    return (1 + sum(dp)) % MOD


def main():
    """Read ``n k`` from one line of stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()