MOD = 998244353


def solve(n, k):
    """Return, modulo MOD, the sum over m >= 0 with n - m*k >= 0 of C(n - m*k + m, m).

    Each term C(n - m*k + m, m) is the number of ways to place m
    non-overlapping segments of length k into a row of n cells (the m = 0
    term contributes the empty placement, i.e. 1), so the result counts all
    such placements.

    Args:
        n: Number of cells (n >= 0).
        k: Segment length; must be a positive integer.

    Returns:
        The count reduced modulo MOD.

    Raises:
        ValueError: if k < 1 (with k == 0 the sum would be unbounded and the
            original script crashed with an IndexError).
    """
    if k < 1:
        raise ValueError("K must be a positive integer")

    # Factorials 0!..n! modulo MOD.  Every binomial below has top index
    # n - m*k + m <= n because k >= 1, so this table is large enough.
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD

    # Inverse factorials: invert n! once, then walk down using
    # 1/(i-1)! = i * (1/i!), instead of one modular exponentiation per index.
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], -1, MOD)
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    def comb(top, bottom):
        """Binomial coefficient C(top, bottom) modulo MOD (0 <= bottom <= top <= n)."""
        return fact[top] * inv_fact[bottom] % MOD * inv_fact[top - bottom] % MOD

    result = 1  # m = 0: the single empty placement.
    # m segments fit exactly when n - m*k >= 0, i.e. m <= n // k.
    for m in range(1, n // k + 1):
        free = n - m * k  # cells left uncovered by the m segments
        result = (result + comb(free + m, m)) % MOD
    return result


def main():
    """Read "N K" from stdin and print solve(N, K)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()