import sys
input = sys.stdin.readline
# Lift the limit on int <-> str conversion length (API exists only on Python 3.11+).
if hasattr(sys, "set_int_max_str_digits"):
    sys.set_int_max_str_digits(0)
from collections import defaultdict, deque, Counter
from heapq import heappop, heappush
from bisect import bisect_left, bisect_right
# gcd(x, y): greatest common divisor, lcm(x, y): least common multiple,
# factorial(n): n!, perm(n, k): nPk, comb(n, r): binomial coefficient nCr
from math import gcd, lcm, factorial, perm, comb
# To reorder the digits 0-9 use permutations or combinations; to build
# every combination of N categorical columns use product.
from itertools import product, permutations, combinations, accumulate
from functools import lru_cache  # usage: @lru_cache(maxsize=128)
import operator
from string import ascii_uppercase, ascii_lowercase, digits  # upper-case letters, lower-case letters, digits
from typing import Optional

MOD = 998244353


def II():
    """Read one line and parse it as a single int."""
    return int(input())


def LI():
    """Read one line and return its characters as a list, without the line terminator."""
    # sys.stdin.readline keeps the trailing newline; strip it so it is not
    # returned as the last "character".
    return list(input().rstrip("\r\n"))


def LMI():
    """Read one line of whitespace-separated ints."""
    return list(map(int, input().split()))


def LMS():
    """Read one line of whitespace-separated tokens as strings."""
    return input().split()


def LLMI(x):
    """Read x lines, each parsed as whitespace-separated ints."""
    return [LMI() for _ in range(x)]


def LLMS(x):
    """Read x lines, each split into whitespace-separated string tokens."""
    return [LMS() for _ in range(x)]


def CUM(x: list, func=None, initial=0) -> list:
    """Return the running accumulation of x.

    func: how values are combined; None means addition.
          Examples: operator.mul (product), operator.sub (difference),
          max (running maximum), min (running minimum).
    initial: value placed in front of the result and used as the first
             left operand; None means x[0] is the first value.
    """
    return list(accumulate(x, func, initial=initial))


def yn(tf: bool):
    """Print 'Yes' when tf is truthy, otherwise 'No'. Returns None."""
    print("Yes" if tf else "No")


class UnionFind():
    """Disjoint-set union over elements 0..n-1 with union by size and path compression."""

    def __init__(self, n):
        self.n = n
        # parents[i] < 0 marks i as a root; -parents[i] is then its set size.
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of x's set, compressing the path on the way."""
        if self.parents[x] < 0:
            return x
        self.parents[x] = self.find(self.parents[x])
        return self.parents[x]

    def union(self, x, y):
        """Merge the sets containing x and y (no-op if already merged)."""
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return
        # Attach the smaller set (less negative size) under the larger one.
        if self.parents[x] > self.parents[y]:
            x, y = y, x
        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        """Return the number of elements in x's set."""
        return -self.parents[self.find(x)]

    def same(self, x, y):
        """Return True if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def members(self, x):
        """Return all elements in x's set, in increasing order."""
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        """Return the root of every set."""
        return [i for i, p in enumerate(self.parents) if p < 0]

    def group_count(self):
        """Return the number of disjoint sets."""
        return len(self.roots())

    def group(self):
        """Return a mapping root -> list of members."""
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.find(member)].append(member)
        return group_members

    def __str__(self):
        # One "root: [members]" line per group.
        return '\n'.join(f'{r}: {m}' for r, m in self.group().items())


def inverse_element(num: int, mod: int = MOD) -> int:
    """Return the modular inverse x of num, i.e. num * x ≡ 1 (mod mod).

    For a prime modulus p, Fermat's little theorem gives a * a^(p-2) ≡ 1
    (mod p), so a^(p-2) is the inverse. pow(num, -1, mod) yields the same
    value and also handles non-prime moduli.

    Raises:
        ValueError: if num is not invertible modulo mod (e.g. num % mod == 0).
    """
    return pow(num, -1, mod)


def make_graph(n: int, lmi: list, idx_0: bool, directed: bool = False):
    """Build an adjacency list for n vertices from a list of [a, b] edges.

    idx_0: if True, the edge endpoints are 1-indexed and are shifted down
           to 0-indexed.
    directed: if True, only a -> b is added; otherwise both directions.
    """
    graph = [[] for _ in range(n)]
    offset = 1 if idx_0 else 0
    for a, b in lmi:
        a -= offset
        b -= offset
        graph[a].append(b)
        if not directed:
            graph[b].append(a)
    return graph


def dfs(n: int, graph: list[list[int]], s: int = 0, g: Optional[int] = None) -> bool:
    """Traverse the graph from s and report whether g is reachable.

    Note: despite the name, vertices are processed in FIFO order (BFS).
    s: start vertex, vertex 0 if not given.
    g: goal vertex; if None, the whole reachable component is traversed.
    Returns True as soon as g is reached, otherwise False (always False
    when g is None).
    """
    queue = deque([s])
    visited = [False] * n
    visited[s] = True
    while queue:
        crr = queue.popleft()
        if g is not None and crr == g:
            return True
        for nxt in graph[crr]:
            if visited[nxt]:
                continue
            queue.append(nxt)
            visited[nxt] = True
    return False


def dijkstra(n: int, graph: list[list[tuple[int, int]]], s: int = 0, unreachable: int = -1) -> list:
    """Single-source shortest paths with non-negative edge weights.

    graph[u] is a list of (v, weight) pairs.
    s: start vertex, vertex 0 if not given.
    unreachable: value reported for vertices that cannot be reached from s
                 (kept distinct from a real distance of 0).
    Returns the list of shortest distances from s.
    """
    que = [(0, s)]
    done = [False] * n
    ans = [unreachable] * n
    while que:
        cnt, crr = heappop(que)
        if done[crr]:
            continue
        # The first pop of a vertex fixes its shortest distance.
        done[crr] = True
        ans[crr] = cnt
        for nxt, val in graph[crr]:
            # Skip vertices whose shortest distance is already fixed.
            if done[nxt]:
                continue
            heappush(que, (cnt + val, nxt))
    return ans


def prime_factorize(n):
    """Return the prime factors of n in ascending order, with multiplicity.

    prime_factorize(1) == []. Raises ValueError for n < 1, because trial
    division never terminates on 0 or negative input.
    """
    if n < 1:
        raise ValueError(f"prime_factorize requires n >= 1, got {n}")
    a = []
    while n % 2 == 0:
        a.append(2)
        n //= 2
    f = 3
    while f * f <= n:
        if n % f == 0:
            a.append(f)
            n //= f
        else:
            f += 2
    if n != 1:
        a.append(n)
    return a


def execute():
    """Read n, k and print dp[n] mod MOD, where dp[0] = 1 and
    dp[i] = dp[i-1] + dp[i-k] (the second term only when i >= k).
    """
    n, k = LMI()
    dp = [0] * (n + 1)
    dp[0] = 1
    for i in range(1, n + 1):
        dp[i] = dp[i - 1]
        dp[i] %= MOD
        if i - k >= 0:
            dp[i] += dp[i - k]
            dp[i] %= MOD
    print(dp[n])


if __name__ == "__main__":
    T = 1
    for _ in range(T):
        execute()