import sys

MOD = 998244353


def solve(n, p, mod=MOD):
    """Return ``n! - S`` modulo ``mod``.

    ``S = sum_{k=0}^{n // p} n! / ((n - k*p)! * p**k * k!)``.  For ``p >= 2``
    the k-th term is the number of permutations of ``n`` elements made of
    exactly ``k`` cycles of length ``p`` with every other element fixed, so
    ``S`` counts permutations whose cycles all have length 1 or ``p``.

    Args:
        n: number of elements (``n >= 0``; ``n < mod`` so that ``n!`` is
            invertible).
        p: cycle length (positive; must not be a multiple of ``mod``).
        mod: prime modulus; defaults to 998244353.

    Returns:
        The value as an int in ``[0, mod)``.  ``n == 1`` returns 0 directly,
        as the original script did.

    Raises:
        ValueError: if ``p < 1`` (the original loop never terminated there).
    """
    if n == 1:
        return 0
    if p < 1:
        raise ValueError("p must be a positive integer")

    # Factorials and inverse factorials sized to n, instead of a fixed cap
    # that raised IndexError for large n.
    fact = [1] * (n + 1)
    for i in range(2, n + 1):
        fact[i] = fact[i - 1] * i % mod
    fact_inv = [1] * (n + 1)
    fact_inv[n] = pow(fact[n], mod - 2, mod)  # Fermat inverse; mod is prime
    for i in range(n, 0, -1):
        fact_inv[i - 1] = fact_inv[i] * i % mod

    inv_p = pow(p, mod - 2, mod)
    total = 1  # k = 0 term: the identity permutation
    inv_p_pow = 1  # running value of inv_p**k, replacing pow() per iteration
    for k in range(1, n // p + 1):
        inv_p_pow = inv_p_pow * inv_p % mod
        term = fact[n] * fact_inv[n - k * p] % mod
        total = (total + term * inv_p_pow % mod * fact_inv[k]) % mod

    return (fact[n] - total) % mod


def main():
    """Read ``N P`` from one line of stdin and print ``solve(N, P)``."""
    n, p = map(int, sys.stdin.readline().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()