"""Read two integers N and K from stdin and print one DP value modulo MOD.

The value printed is a(N) for the sequence defined by a(0) = 1 and
a(i) = a(i-1) + a(i-K), where terms with a negative index count as 0.
"""
import sys

MOD = 998244353


def solve(n: int, k: int) -> int:
    """Return a(n) modulo MOD, where a(0) = 1 and a(i) = a(i-1) + a(i-k).

    Terms with a negative index are treated as 0.  The table is filled by
    forward propagation: once ``dp[i]`` is final it is added into
    ``dp[i + k]``, so that cell already holds a(i) when the loop reaches it.

    Assumes ``n >= 0`` and ``k >= 1``; the input is not validated.
    """
    dp = [0] * (n + 1)
    dp[0] = 1
    for i in range(n + 1):
        # Skip the "previous cell" term at i == 0: dp[-1] would wrap around
        # to the last cell, which for n == 0 is dp[0] itself and would
        # double the base case.
        if i:
            dp[i] = (dp[i - 1] + dp[i]) % MOD
        if i + k <= n:
            # Left unreduced here; it is reduced when index i + k is visited.
            dp[i + k] += dp[i]
    return dp[n] % MOD


def main() -> None:
    """Parse ``N K`` from the first line of stdin and print the result."""
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()