"""Count, modulo 998244353, permutations of N items that have at least one
cycle whose length is neither 1 nor P.

The script reads two integers ``N P`` from standard input and prints
``N! - (number of permutations whose cycles all have length 1 or P)``.

NOTE(review): the formula double-counts when P == 1, so P >= 2 is presumably
guaranteed by the input constraints -- confirm against the problem statement.
"""

MOD = 998244353


class CP:
    """Factorial tables modulo ``MOD`` for combinations and permutations.

    Attributes:
        fact: ``fact[i] == i! % MOD`` for ``0 <= i <= N``.
        fact_inv: ``fact_inv[i]`` is the modular inverse of ``fact[i]``.
    """

    def __init__(self, N):
        """Precompute ``i!`` and its inverse for every ``0 <= i <= N``.

        ``N`` must be smaller than ``MOD``; otherwise ``N!`` is 0 modulo
        ``MOD`` and has no inverse (``pow`` raises ``ValueError``).
        """
        size = max(N, 0)
        fact = [1] * (size + 1)
        for i in range(1, size + 1):
            fact[i] = fact[i - 1] * i % MOD

        # One modular inversion for the largest factorial, then walk down
        # using 1/(i-1)! == i * 1/i!  -- avoids one exponentiation per entry.
        fact_inv = [1] * (size + 1)
        fact_inv[size] = pow(fact[size], -1, MOD)
        for i in range(size, 0, -1):
            fact_inv[i - 1] = fact_inv[i] * i % MOD

        self.fact = fact
        self.fact_inv = fact_inv

    def C(self, N, K):
        """Return the binomial coefficient C(N, K) modulo ``MOD``.

        Returns 0 when ``K`` is outside ``0..N``.  The lower bound matters:
        a negative ``K`` would otherwise index the tables from the end and
        silently produce a wrong value.
        """
        if not 0 <= K <= N:
            return 0
        return self.fact[N] * self.fact_inv[K] % MOD * self.fact_inv[N - K] % MOD

    def P(self, N, K):
        """Return the number of K-permutations of N items modulo ``MOD``.

        Returns 0 when ``K`` is outside ``0..N``.
        """
        if not 0 <= K <= N:
            return 0
        return self.fact[N] * self.fact_inv[N - K] % MOD


def solve(N, P):
    """Return ``N!`` minus the permutations built only from 1- and P-cycles.

    For each ``i`` in ``0..N // P`` the subtracted term is

        C(N, P*i) * blocks(i) * ((P-1)!)**i

    where ``blocks(i) = prod_{j=1..i} C(j*P - 1, P - 1)`` is the number of
    ways to split ``P*i`` labelled items into ``i`` unordered groups of size
    ``P``, and ``(P-1)!`` is the number of cyclic orders of one group.  The
    remaining ``N - P*i`` items are fixed points.

    The result is reduced into ``0..MOD-1``.
    """
    cp = CP(N)
    fact = cp.fact

    excluded = 0
    blocks = 1      # running product of C(j*P - 1, P - 1) for j = 1..i
    cycle_pow = 1   # running ((P-1)!)**i
    for i in range(N // P + 1):
        if i >= 1:
            # fact[P - 1] is only touched when P <= N, so it is in range.
            blocks = blocks * cp.C(i * P - 1, P - 1) % MOD
            cycle_pow = cycle_pow * fact[P - 1] % MOD
        excluded += cp.C(N, P * i) * blocks % MOD * cycle_pow % MOD

    return (fact[N] - excluded) % MOD


def main():
    """Read ``N P`` from standard input and print the answer."""
    N, P = map(int, input().split())
    print(solve(N, P))


if __name__ == "__main__":
    main()