MOD = 998244353


def solve(n, m):
    """Return sum_{k=m}^{n} C(n, k) modulo MOD.

    The upper tail is computed as 2**n minus the lower tail
    sum_{k<m} C(n, k), so the loop runs m times no matter how large n is.

    Args:
        n: total count (may be far larger than MOD).
        m: lower bound of the summed k range.

    Returns:
        The tail sum reduced into [0, MOD). Returns 0 when m > n (empty sum).

    NOTE(review): the loop is O(m) and relies on k + 1 < MOD so that every
    (k + 1) is invertible; presumably the problem guarantees m < MOD -- confirm.
    """
    if m > n:
        # No k satisfies m <= k <= n.
        return 0
    # 2 is coprime to the prime MOD, so by Fermat the exponent may be
    # reduced modulo MOD - 1.
    ans = pow(2, n % (MOD - 1), MOD)
    binom = 1  # C(n, k) mod MOD for the current k, starting at C(n, 0)
    for k in range(m):
        ans = (ans - binom) % MOD
        # C(n, k+1) = C(n, k) * (n - k) / (k + 1). For k + 1 < MOD the
        # congruence holds even when n >= MOD, because k! is invertible.
        binom = binom * ((n - k) % MOD) % MOD
        binom = binom * pow(k + 1, MOD - 2, MOD) % MOD
    return ans


def main():
    """Read N (first line) and M (second line) from stdin; print solve(N, M)."""
    from sys import stdin

    readline = stdin.readline
    n = int(readline())
    m = int(readline())
    print(solve(n, m))


if __name__ == "__main__":
    main()