"""Read ``N K`` from stdin and print a prefix-sum recurrence value modulo MOD."""

MOD = 998244353


def solve(n, k):
    """Return the n-th term of the running-sum recurrence, modulo ``MOD``.

    The sequence ``S`` is defined by::

        S[i] = 1                          for 0 <= i < k
        S[i] = (S[i-1] + S[i-k]) % MOD    for i >= k

    Equivalently, with ``dp[0] = 1``, ``dp[i] = 0`` for ``0 < i < k`` and
    ``dp[i] = S[i-k]`` for ``i >= k``, ``S`` is the running sum of ``dp``.
    The ``dp`` values never need to be stored separately because each one
    is just an earlier entry of ``S``.

    Args:
        n: Index of the term to compute.
        k: Lag of the recurrence; must be at least 1.

    Returns:
        ``S[n]`` when ``n >= k``; otherwise the last seed value, which is 1.

    Raises:
        ValueError: If ``k`` is smaller than 1 (the recurrence would refer
            to the term being defined, or to terms that do not exist).
    """
    if k < 1:
        raise ValueError("k must be at least 1")

    # Seed values S[0..k-1]; each equals 1 because only dp[0] is non-zero
    # among the first k dp entries.
    prefix = [1] * k
    for i in range(k, n + 1):
        prefix.append((prefix[-1] + prefix[i - k]) % MOD)

    # When n < k the loop does not run and the last seed (1) is reported.
    return prefix[-1]


def main():
    """Parse ``N K`` from one line of stdin and print the result."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()