"""Print C(2^N - 1, M) - (2^N - 1) * C(2^(N-1) - 1, M - 1) modulo 998244353.

Input: one line on standard input holding two integers ``N M``.
Output: the value above, reduced into ``[0, MOD)``, on standard output.
"""
from sys import stdin

MOD = 998244353


def inverse(n, d):
    """Return ``n / d`` modulo ``MOD``, i.e. ``n * d^-1 % MOD``.

    ``pow(d, -1, MOD)`` raises ``ValueError`` when ``d`` has no inverse
    modulo ``MOD`` (``d`` divisible by ``MOD``).
    """
    return n * pow(d, -1, MOD) % MOD


def binomial_mod(n_mod, k):
    """Return the binomial coefficient C(n, k) modulo ``MOD``.

    ``n_mod`` is the upper index already reduced modulo ``MOD``; that is
    sufficient because only the falling factorial
    ``n * (n - 1) * ... * (n - k + 1)`` is needed, and it depends on ``n``
    only through its residue.

    For ``k <= 0`` the loop body never runs and the result is 1.
    Raises ``ValueError`` (from ``inverse``) when ``k!`` is divisible by
    ``MOD``, i.e. for ``k >= MOD``.
    """
    numerator = 1    # n * (n - 1) * ... * (n - k + 1)  (mod MOD)
    denominator = 1  # k!                                (mod MOD)
    for i in range(1, k + 1):
        numerator = numerator * (n_mod - i + 1) % MOD
        denominator = denominator * i % MOD
    return inverse(numerator, denominator)


def solve(N, M):
    """Return C(2^N - 1, M) - (2^N - 1) * C(2^(N-1) - 1, M - 1) mod ``MOD``."""
    full = (pow(2, N, MOD) - 1) % MOD      # 2^N - 1
    half = (pow(2, N - 1, MOD) - 1) % MOD  # 2^(N-1) - 1
    # The final reduction also maps a negative difference back into [0, MOD).
    return (binomial_mod(full, M) - binomial_mod(half, M - 1) * full) % MOD


def main():
    """Read ``N M`` from standard input and print the answer."""
    N, M = map(int, stdin.readline().split())
    print(solve(N, M))


if __name__ == "__main__":
    main()