#include<bits/stdc++.h>
#include<atcoder/all>
using mint=atcoder::modint998244353;
using namespace std;
int main(){
	int n,p;
	scanf("%d%d",&n,&p);
	vector<mint> fac(n+1);
	fac[0]=1;
	for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i;
	mint ans=fac[n];
	for(int i=0;i<=n/p;i++){
		ans-=fac[n]/mint(p).pow(i)/fac[i]/fac[n-p*i];
	}
	printf("%d\n",ans.val());
}