def main(): N,K = map(int,input().split()) mod = 998244353 dp = [0 for _ in range(N+1)] dp[0] = 1 acum = [] for i in range(min(K,N+1)): acum.append(1) for i in range(K,N+1): dp[i] += acum[i-K] dp[i] = dp[i] % mod acum.append(acum[-1]+dp[i]) print(sum(dp)%mod) if __name__ == '__main__': main()