"""Print C(n, k) modulo the prime 998244353 for ``n k`` read from stdin.

For ``min(k, n - k)`` up to ``SMALL_K_LIMIT`` the coefficient is built as a
direct product; otherwise it is formed from factorials, each of which is
resumed from a table of factorials at multiples of ``L``.  Arguments at or
above the modulus are reduced digit-by-digit with Lucas's theorem, because
the factorial of any value >= M is 0 mod M and cannot be inverted.
"""

M = 998244353  # prime modulus
L = 10**8  # spacing of the precomputed factorial table
# p[i] is (i * L)! mod M -- the starting point f() resumes from.  p[10] is 0
# because 10**9 > M, so that factorial contains the factor M.
p = [
    1, 808258749, 117153405, 761699708, 573994984, 62402409,
    511621808, 242726978, 887890124, 875880304, 0,
]

# Largest min(k, n - k) computed by the direct product; above it the
# factorial table is used instead.
# NOTE(review): presumably a tuned cut-over point between the two methods.
SMALL_K_LIMIT = 1755647


def f(x):
    """Return ``x! mod M`` for ``0 <= x < len(p) * L``.

    Starts from the tabulated factorial ``p[x // L]`` and multiplies in the
    remaining factors, so it performs ``x % L`` multiplications.
    """
    block = x // L
    a = p[block]
    for i in range(block * L + 1, x + 1):
        a = a * i % M
    return a


def binom_lt_mod(n, k, small_k_limit=SMALL_K_LIMIT):
    """Return ``C(n, k) mod M`` for ``0 <= k <= n < M``.

    Uses a direct product when ``min(k, n - k) <= small_k_limit`` and the
    factorial table otherwise.  With ``n < M`` every denominator factor is
    non-zero mod the prime M, so the Fermat inverse always exists.
    """
    k = min(k, n - k)
    if k <= small_k_limit:
        # Accumulate numerator and denominator separately and invert once,
        # instead of one modular exponentiation per factor.
        num = den = 1
        for i in range(k):
            num = num * (n - i) % M
            den = den * (i + 1) % M
        return num * pow(den, M - 2, M) % M
    return f(n) * pow(f(k) * f(n - k) % M, M - 2, M) % M


def binom(n, k):
    """Return ``C(n, k) mod M`` for any ``n >= 0``; 0 when k < 0 or k > n.

    Lucas's theorem: C(n, k) is congruent to the product of C(n_i, k_i) over
    the base-M digits n_i, k_i, and any digit pair with k_i > n_i makes the
    whole coefficient 0.
    """
    if k < 0 or k > n:
        return 0
    result = 1
    # Once k's digits run out, the remaining factors are C(n_i, 0) = 1.
    while k:
        n, n_digit = divmod(n, M)
        k, k_digit = divmod(k, M)
        if k_digit > n_digit:
            return 0
        result = result * binom_lt_mod(n_digit, k_digit) % M
    return result


def main():
    """Read ``n k`` from stdin and print C(n, k) mod M."""
    n, k = map(int, input().split())
    print(binom(n, k))


if __name__ == "__main__":
    main()