"""Sum a suffix-propagating recurrence over 0..N, modulo 998244353.

Reads two integers ``N K`` from standard input and prints
``sum(f(j) for j in 0..N) % 998244353`` where ``f(0) = 1`` and, for
``j >= 1``, ``f(j) = f(0) + f(1) + ... + f(j - K)`` (an empty sum when
``j < K``).  Equivalently, ``f(j)`` counts the ordered sequences of
integers, each at least ``K``, whose sum is exactly ``j``.
"""

MOD = 998244353


def count_ways(n: int, k: int) -> int:
    """Return ``sum(f(j) for j in range(n + 1)) % MOD`` for the recurrence above.

    The values are built with a difference array: ``diff[i]`` holds a
    pending delta until index ``i`` is visited, at which point a running
    prefix sum turns it into the real value ``f(i)``.  Each ``f(i)`` is then
    pushed forward as a range-add over ``[i + k, n]``.

    Assumes ``n >= 0`` and ``k >= 1``; other inputs are not validated
    (NOTE(review): confirm the intended input constraints).
    """
    # Slots 0..n carry the values; slot n + 1 is the sentinel that closes
    # every range-add, so it is back to 0 (mod MOD) once the sweep ends.
    diff = [0] * (n + 2)

    # f(0) = 1 expressed as a range-add over the single index [0, 0].
    diff[0] = 1
    diff[1] = -1

    for i in range(n + 2):
        if i:
            # Prefix-sum step: after this line diff[i] is the real f(i).
            diff[i] = (diff[i] + diff[i - 1]) % MOD
        if i + k <= n:
            # f(i) contributes to every f(j) with i + k <= j <= n.
            diff[i + k] = (diff[i + k] + diff[i]) % MOD
            diff[n + 1] = (diff[n + 1] - diff[i]) % MOD

    return sum(diff) % MOD


def main() -> None:
    """Read ``N K`` from standard input and print the answer."""
    n, k = map(int, input().split())
    print(count_ways(n, k))


if __name__ == "__main__":
    main()