from functools import lru_cache

# Modulus the answer is reported under.
MOD = 998244353


@lru_cache
def nPr(n, r, mod=None):
    """Return the falling factorial n * (n-1) * ... * (n-r+1), i.e. P(n, r).

    Args:
        n: Upper value of the product.
        r: Number of factors; ``r <= 0`` yields the empty product 1.
        mod: If given, reduce the running product modulo ``mod`` at every
            step so intermediates stay small; if ``None`` (default) the
            exact integer is returned, as before.

    Returns:
        P(n, r), exactly or modulo ``mod``.
    """
    product = 1
    for factor in range(n, n - r, -1):
        product *= factor
        if mod is not None:
            product %= mod
    return product


def sum_of_permutations(n, m, mod=MOD):
    """Return (sum of P(n, i) for i in 0..n-m) modulo ``mod``.

    Returns 0 when ``n < m`` (the range of i is empty).

    Uses the recurrence P(n, i+1) = P(n, i) * (n - i), so the whole sum
    costs O(n - m) modular multiplications. Recomputing each P(n, i) from
    scratch as an exact big integer would be quadratic on huge numbers.
    """
    if n - m < 0:
        return 0
    total = 0
    term = 1  # P(n, 0): the empty product
    for i in range(n - m + 1):
        total = (total + term) % mod
        # Advance term from P(n, i) to P(n, i + 1).
        term = term * (n - i) % mod
    return total


def main():
    """Read n then m from stdin (one integer per line) and print the answer."""
    n = int(input())
    m = int(input())
    print(sum_of_permutations(n, m))


if __name__ == "__main__":
    main()