import sys
from functools import lru_cache  # NOTE: unused since the DP became iterative; kept for compatibility

MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return S(n) modulo ``mod`` for the recurrence below.

    Recurrence:
        f(0) = 1
        f(i) = 0             for 0 < i < k
        f(i) = S(i - k)      for i >= k
        S(x) = f(0) + f(1) + ... + f(x)   (prefix sum), S(x) = 0 for x < 0

    The values are computed bottom-up in O(n) time. An earlier memoized
    recursive version needed a call depth proportional to n and overflowed
    the interpreter stack for large n.

    Args:
        n: Upper index of the prefix sum. Negative n yields 0, matching S(x < 0) = 0.
        k: Offset in the recurrence. Must be >= 1.
        mod: Modulus applied to every intermediate value.

    Returns:
        S(n) % mod.

    Raises:
        ValueError: if k < 1. For such k the recurrence refers to itself
            (k == 0) or to ever-larger indices (k < 0) and has no base case.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if n < 0:
        return 0

    # prefix[x] holds S(x) % mod.
    prefix = [0] * (n + 1)
    running = 0
    for i in range(n + 1):
        if i == 0:
            f_i = 1
        elif i < k:
            f_i = 0
        else:
            f_i = prefix[i - k]  # f(i) = S(i - k); i - k < i, so it is already computed
        running = (running + f_i) % mod
        prefix[i] = running
    return prefix[n]


def main():
    """Read "N K" from stdin and print S(N) modulo MOD."""
    N, K = map(int, sys.stdin.readline().split())
    # The answer is S(N).
    print(solve(N, K) % MOD)


if __name__ == "__main__":
    main()