# Prime modulus used for the printed answer.
MOD = 998244353


def binomial_mod(n: int, k: int, mod: int = MOD) -> int:
    """Return the binomial coefficient C(n, k) reduced modulo *mod*.

    The coefficient is built exactly with the multiplicative recurrence
    C(n, i + 1) = C(n, i) * (n - i) // (i + 1) and reduced only at the end,
    so the result does not depend on *mod* being prime or larger than *n*.

    Args:
        n: Upper index of the coefficient.
        k: Lower index of the coefficient.
        mod: Modulus applied to the final exact value.

    Returns:
        C(n, k) % mod.  For ``k <= 0`` the product is empty and the result
        is ``1 % mod``; for ``0 <= n < k`` the result is 0; for negative
        ``n`` the same recurrence yields the generalized coefficient.
    """
    if 0 <= n < k:
        # The factor (n - i) reaches zero at i == n, so the product is 0;
        # skip the loop entirely.
        return 0
    if 0 <= k <= n:
        # C(n, k) == C(n, n - k): iterate over the shorter side to keep the
        # intermediate big integers small.
        k = min(k, n - k)

    result = 1
    for i in range(k):
        # Exact at every step: result * (n - i) equals (i + 1) * C(n, i + 1).
        result = result * (n - i) // (i + 1)
    return result % mod


def main() -> None:
    """Read ``N K`` from standard input and print C(N, K) modulo MOD."""
    n, k = map(int, input().split())
    print(binomial_mod(n, k))


if __name__ == "__main__":
    main()