MOD = 998244353


def binomial_tail_sum(n, m, mod=MOD):
    """Return sum(C(n, k) for k in range(m, n + 1)) modulo *mod*.

    Equivalently: the number of subsets of an n-element set having at least
    m elements, reduced modulo *mod*.

    Computed as 2**n minus the lower tail sum(C(n, k) for k < m), so only
    O(m) binomial terms are needed. *mod* must be prime (modular inverses
    are taken via Fermat's little theorem). The lower tail divides by
    1..m-1, so m is assumed not to exceed *mod* — TODO confirm against
    the problem constraints.
    """
    if n < m:
        return 0
    if n == m:
        return 1
    # The exponent must NOT be reduced mod p: 2**n mod p depends on
    # n mod (p - 1), not n mod p, so use the original n here.
    total = pow(2, n, mod)
    # For k < p, C(n, k) ≡ C(n mod p, k) (mod p) by Lucas' theorem, so the
    # falling-factorial numerator may safely use the reduced n.
    n_mod = n % mod
    lower = 0
    term = 1  # C(n, 0)
    for k in range(m):
        lower = (lower + term) % mod
        # C(n, k + 1) = C(n, k) * (n - k) / (k + 1)
        term = term * (n_mod - k) % mod * pow(k + 1, mod - 2, mod) % mod
    return (total - lower) % mod


def main():
    """Read N and M (one per line) from stdin and print the tail sum."""
    n = int(input())
    m = int(input())
    print(binomial_tail_sum(n, m))


if __name__ == "__main__":
    main()