M = 998244353  # prime modulus; the Fermat inverse below relies on primality


def solve(n, k, mod=M):
    """Return sum_{i>=0} C(n - i*(k-1), i) modulo ``mod``.

    Term i chooses which i of the (n - i*k) + i pieces are "long". It is
    non-zero only while n - i*k >= 0. Equivalently, the result is the number
    of ways to tile a length-n strip with tiles of length 1 and length k.

    Args:
        n: non-negative length; assumed smaller than ``mod`` so that n! is
            invertible modulo ``mod``.
        k: length of the long tile; must be >= 1.
        mod: prime modulus (defaults to 998244353).

    Returns:
        The sum above reduced modulo ``mod``.

    Raises:
        ValueError: if k < 1 (the sum is unbounded / indices exceed n).
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    if n < 0:
        raise ValueError("n must be >= 0")

    # Factorials 0!..n! modulo mod.
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % mod

    # Inverse factorials: invert n! once (Fermat, mod is prime), then walk
    # down using 1/(i-1)! = i * (1/i!).
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], mod - 2, mod)
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % mod

    total = 0
    # Only i with i*k <= n contribute; stop there instead of scanning to n.
    for i in range(n // k + 1):
        top = n - i * k + i  # number of pieces: (n - i*k) short + i long
        total += fact[top] * inv_fact[top - i] % mod * inv_fact[i]
    return total % mod


def main():
    """Read ``n k`` from stdin and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()