"""Read two integers N and P and print (N! - S) mod 998244353.

S is the sum over i = 0 .. N // P of

    C(N, i*P) * (i*P)! / ((P!)^i * i!) * ((P-1)!)^i

which counts the ways to pick i disjoint P-cycles out of N elements and
leave every other element fixed.
NOTE(review): presumably the task is "count permutations of N elements
that are NOT made only of fixed points and P-cycles" (for prime P, those
whose P-th power is not the identity) -- confirm against the statement.
"""

MOD = 998244353


def build_factorials(limit):
    """Return (fact, inv_fact) tables modulo MOD for 0..limit inclusive.

    ``fact[k]`` is k! mod MOD and ``inv_fact[k]`` is its modular inverse.
    Requires ``limit < MOD`` so that limit! is invertible.
    """
    fact = [1] * (limit + 1)
    for k in range(1, limit + 1):
        fact[k] = fact[k - 1] * k % MOD

    inv_fact = [1] * (limit + 1)
    # One Fermat inversion for the largest factorial, then walk downwards:
    # 1/(k-1)! = (1/k!) * k.
    inv_fact[limit] = pow(fact[limit], MOD - 2, MOD)
    for k in range(limit, 0, -1):
        inv_fact[k - 1] = inv_fact[k] * k % MOD
    return fact, inv_fact


def count_permutations(n, p):
    """Return (n! - S) mod MOD, with S as described in the module docstring.

    Args:
        n: number of elements (non-negative).
        p: cycle length used in the sum (positive).

    Raises:
        ValueError: if ``n`` is negative or ``p`` is not positive.
    """
    if n < 0 or p < 1:
        raise ValueError("expected n >= 0 and p >= 1")

    # Tables are sized by n (every index used below is <= n) instead of a
    # fixed 4 * 10**5, so large n works and small n stays cheap.
    fact, inv_fact = build_factorials(n)

    def binomial(a, b):
        """Binomial coefficient C(a, b) mod MOD; 0 when b is out of range."""
        if 0 <= b <= a:
            return fact[a] * inv_fact[b] % MOD * inv_fact[a - b] % MOD
        return 0

    if p > n:
        # No p-cycle fits, so only the i = 0 term (value 1) exists.
        return (fact[n] - 1) % MOD

    cycle_orders = fact[p - 1]  # (p-1)! cyclic orders per chosen group
    total = 0
    groupings = 1    # running (i*p)! / (p!)^i
    cycle_weight = 1  # running ((p-1)!)^i, replaces a pow() per iteration
    for i in range(n // p + 1):
        used = i * p
        if i:
            groupings = groupings * binomial(used, p) % MOD
            cycle_weight = cycle_weight * cycle_orders % MOD
        # inv_fact[i] removes the ordering among the i groups.
        term = groupings * inv_fact[i] % MOD * cycle_weight % MOD
        total = (total + term * binomial(n, used)) % MOD

    return (fact[n] - total) % MOD


def main():
    """Read ``N P`` from standard input and print the answer."""
    n, p = map(int, input().split())
    print(count_permutations(n, p))


if __name__ == "__main__":
    main()