MOD = 998244353  # default prime modulus


def solve_problem(n, k, mod=MOD):
    """Return the binomial coefficient C(n, k) modulo the prime ``mod``.

    Computes n * (n-1) * ... * (n-k+1) / k! using a single modular inverse
    of k! (Fermat's little theorem), so ``mod`` must be prime.

    Args:
        n: Size of the set to choose from.
        k: Number of elements to choose.
        mod: Prime modulus; defaults to 998244353.

    Returns:
        C(n, k) % mod, and 0 when k < 0 or k > n.

    Note:
        The result is exact modulo ``mod`` as long as min(k, n - k) < mod
        (so k! is invertible). Larger reduced k would need Lucas' theorem.
    """
    if k < 0 or k > n:
        return 0
    # C(n, k) == C(n, n - k); iterate over the shorter side.
    k = min(k, n - k)
    numerator = 1
    denominator = 1
    for i in range(1, k + 1):
        numerator = numerator * (n - i + 1) % mod
        denominator = denominator * i % mod
    # One modular inverse of k! instead of one per loop iteration.
    return numerator * pow(denominator, mod - 2, mod) % mod


if __name__ == "__main__":
    # Input: a single line "n k". Unpack into two ints before calling;
    # the original passed the map object itself and splatted the int result.
    n, k = map(int, input().split())
    print(solve_problem(n, k))