MOD = 998244353


def count_subsets_at_least(n, m, mod=MOD):
    """Return the number of non-empty subsets of an n-element set with >= m elements, modulo ``mod``.

    Equivalently ``sum(C(n, k) for k in range(max(m, 1), n + 1)) % mod``.
    It is computed as ``(2**n - 1) - sum(C(n, k) for k in 1..m-1)``, with each
    binomial derived from the previous one via ``C(n, k) = C(n, k-1) * (n-k+1) / k``.
    The division uses a Fermat inverse, so every k in 1..m-1 must be invertible
    modulo ``mod`` (true when ``mod`` is prime and ``m <= mod``).

    Args:
        n: Size of the ground set.
        m: Minimum subset size to count.
        mod: Modulus for the result; defaults to the prime 998244353.

    Returns:
        The count reduced modulo ``mod``; 0 when ``n < m``.
    """
    if n < m:
        # No subset can have more elements than the ground set.
        return 0
    total = pow(2, n, mod) - 1  # all non-empty subsets
    binom = 1  # C(n, 0)
    for k in range(1, m):
        # Step from C(n, k-1) to C(n, k); pow(k, mod-2, mod) is k^-1 mod a prime.
        binom = binom * ((n - k + 1) % mod) % mod
        binom = binom * pow(k, mod - 2, mod) % mod
        total -= binom  # drop subsets of size k < m
    return total % mod


def main():
    """Read n and then m (one integer per line) from stdin and print the count."""
    n = int(input())
    m = int(input())
    print(count_subsets_at_least(n, m))


if __name__ == "__main__":
    main()