import sys

MOD = 998244353


def solve(n, k, mod=MOD):
    """Return the DP answer for inputs ``N = n`` and ``K = k``, modulo *mod*.

    The table ``dp`` has ``n + 1`` slots, is seeded with ``dp[k] = 1`` and
    filled left to right: after ``dp[i]`` is final, the prefix sum
    ``dp[0] + ... + dp[i]`` is added into ``dp[i + k]`` (when that index is in
    range).  The answer is ``1`` plus the sum, over every ``i`` in
    ``0..n``, of the prefix sums ``dp[0] + ... + dp[i]``.

    When ``k > n`` the table cannot be seeded and the answer is ``1``.

    >>> solve(3, 1)
    8
    >>> solve(4, 2)
    5
    >>> solve(1, 2)
    1
    """
    if k > n:
        # Seed index is out of range: nothing to count beyond the base "1".
        return 1

    dp = [0] * (n + 1)
    dp[k] = 1

    # Forward pass: push each prefix sum k steps ahead.  Keeping the running
    # sum reduced modulo `mod` keeps integers small; a zero residue adds
    # nothing, so skipping it is equivalent to the original `s == 0` check.
    prefix = 0
    for i in range(n + 1):
        prefix = (prefix + dp[i]) % mod
        if prefix and i + k <= n:
            dp[i + k] = (dp[i + k] + prefix) % mod

    # Final pass: accumulate all prefix sums on top of the base value 1.
    answer = 1
    prefix = 0
    for value in dp:
        prefix = (prefix + value) % mod
        answer = (answer + prefix) % mod
    return answer


def main():
    """Read ``N K`` from standard input and print ``solve(N, K)``."""
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()