MOD = 998244353


def count_non_p_permutations(n, p, mod=MOD):
    """Return (n! - 1 - S) modulo `mod`, where S counts sets of disjoint p-cycles.

    S = sum over k = 1 .. n // p of the number of ways to choose k pairwise
    disjoint p-cycles on n labelled points. Adding back the identity (the
    ``- 1``), ``1 + S`` is the number of permutations of n points whose every
    cycle has length 1 or p. So the result is n! minus that count, reduced
    mod `mod`.

    NOTE(review): for p == 1 the "1-cycles" coincide with fixed points, so
    S double counts and this formula no longer matches that reading. The
    original script presumably assumes p >= 2 — confirm against the problem
    constraints. p == 0 raises ZeroDivisionError (n // p), as before.

    Args:
        n: number of points (n >= 0).
        p: cycle length being excluded.
        mod: modulus; must be prime, because inverses use Fermat's little
            theorem (``pow(x, mod - 2, mod)``).

    Returns:
        The count described above, in ``range(mod)``.
    """
    # fact[k] = k! mod `mod`.
    fact = [1] * (n + 1)
    for k in range(1, n + 1):
        fact[k] = fact[k - 1] * k % mod

    # inv_fact[k] = (k!)^-1 mod `mod`. Invert n! once, then walk downwards
    # using (k-1)!^-1 = k!^-1 * k.
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], mod - 2, mod)
    for k in range(n, 0, -1):
        inv_fact[k - 1] = inv_fact[k] * k % mod

    def comb(a, b):
        """Binomial coefficient C(a, b) mod `mod`; 0 when b is out of range."""
        if 0 <= b <= a:
            return fact[a] * inv_fact[a - b] % mod * inv_fact[b] % mod
        return 0

    # `ordered` = number of ORDERED sequences of (i+1) disjoint p-cycles.
    # Each step picks p of the remaining n - i*p points (C(., p)) and
    # arranges them into one cycle ((p-1)! ways).
    ordered = 1
    ans = fact[n] - 1  # the - 1 removes the identity permutation
    for i in range(n // p):
        ordered = ordered * fact[p - 1] % mod * comb(n - i * p, p) % mod
        # Divide by (i+1)! to turn ordered sequences into unordered sets.
        ans -= ordered * inv_fact[i + 1] % mod
    # `ans` may be negative after the subtractions; normalise into [0, mod).
    return ans % mod


def main():
    """Read "n p" from stdin and print count_non_p_permutations(n, p)."""
    n, p = map(int, input().split())
    print(count_non_p_permutations(n, p))


if __name__ == "__main__":
    main()