def _binom_small(n, k, mod):
    """Return C(n, k) % mod for 0 <= k <= n < mod, with mod prime.

    Because n < mod, every factor of k! is nonzero modulo the prime, so
    k! is invertible and a single Fermat inverse suffices.
    """
    k = min(k, n - k)  # symmetry: C(n, k) == C(n, n - k); shorter loop
    numerator = 1
    denominator = 1
    for i in range(1, k + 1):
        numerator = numerator * (n - i + 1) % mod
        denominator = denominator * i % mod
    # Fermat's little theorem: d^(p-2) is the inverse of d modulo prime p.
    return numerator * pow(denominator, mod - 2, mod) % mod


def solve_problem(n, k):
    """Return the binomial coefficient C(n, k) modulo 998244353.

    Args:
        n: Size of the set to choose from.
        k: Number of elements chosen.

    Returns:
        C(n, k) % 998244353. The result is 0 when k < 0 or k > n.
    """
    MOD = 998244353
    if k < 0 or k > n:
        return 0
    result = 1
    # Lucas's theorem (MOD is prime): C(n, k) is congruent to the product of
    # C(n_i, k_i) over the base-MOD digits of n and k. For n < MOD this is a
    # single iteration. It stays correct when k >= MOD, where k! would not be
    # invertible.
    while k:
        n, n_digit = divmod(n, MOD)
        k, k_digit = divmod(k, MOD)
        if k_digit > n_digit:
            return 0
        result = result * _binom_small(n_digit, k_digit, MOD) % MOD
    return result


if __name__ == "__main__":
    # Only read stdin when run as a script, so importing has no side effects.
    print(solve_problem(*map(int, input().split())))