#include #include using mint=atcoder::modint998244353; using namespace std; int main(){ int n,p; scanf("%d%d",&n,&p); vector fac(n+1); fac[0]=1; for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i; mint ans=fac[n]; for(int i=0;i<=n/p;i++){ ans-=fac[n]/mint(p).pow(i)/fac[i]/fac[n-p*i]; } printf("%d\n",ans.val()); }