import java.util.Scanner;

/**
 * Reads two integers {@code n} and {@code m} from standard input and prints
 * {@code C(2^n - 1, m) - (2^n - 1) * C(2^(n-1) - 1, m - 1)} reduced modulo 998244353.
 */
public class Main {
    public static void main(String[] args) {
        int n;
        int m;
        try (Scanner sc = new Scanner(System.in)) {
            n = sc.nextInt();
            m = sc.nextInt();
        }
        int mod = 998244353;

        // 2^n mod p is never 0 for the odd prime p, so these differences are >= 0.
        long n1 = power(2, n, mod) - 1;
        long n2 = power(2, n - 1, mod) - 1;

        long v1 = binomMod(n1, m, mod);
        // NOTE(review): for m == 0 the empty product makes this C(n2, -1) == 1 rather than 0;
        // kept as in the original -- confirm m >= 1 is guaranteed by the input constraints.
        long v2 = binomMod(n2, m - 1, mod) * n1 % mod;

        long ans = Math.floorMod(v1 - v2, (long) mod);
        System.out.println(ans);
    }

    /**
     * Computes {@code C(top, k) mod mod} as {@code top*(top-1)*...*(top-k+1) / k!}.
     *
     * <p>Because {@code k!} is invertible modulo a prime whenever {@code k < mod}, the
     * numerator only depends on {@code top mod mod}. That is why a residue may be passed
     * for a huge {@code top}.
     *
     * @param top the upper index, already reduced modulo {@code mod}
     * @param k the lower index; values {@code <= 0} produce the empty product 1
     * @param mod a prime modulus greater than {@code k}
     * @return the binomial coefficient modulo {@code mod}, in {@code [0, mod)}
     */
    static long binomMod(long top, int k, int mod) {
        long result = 1;
        for (int i = 1; i <= k; i++) {
            // The factor top - i + 1 can go negative once i exceeds top + 1; floorMod
            // maps it back into [0, mod) no matter how negative it is.
            long factor = Math.floorMod(top - i + 1, (long) mod);
            result = result * factor % mod;
            result = result * modinv(i, mod) % mod;
        }
        return result;
    }

    /**
     * Computes {@code x^n mod m} by recursive binary exponentiation.
     *
     * @param x the base
     * @param n the exponent; {@code n <= 0} yields 1
     * @param m the modulus
     * @return {@code x^n mod m}
     */
    static long power(long x, long n, int m) {
        if (n == 0) {
            return 1;
        }
        long half = power(x, n / 2, m);
        long val = half * half % m;
        // For negative n, n % 2 is never 1, so the result stays 1.
        if (n % 2 == 1) {
            x %= m;
            val = val * x % m;
        }
        return val;
    }

    /**
     * Computes the modular inverse of {@code a} modulo {@code m} with the extended
     * Euclidean algorithm.
     *
     * @param a the value to invert; it must be coprime to {@code m} for the result to be an
     *     inverse (a multiple of {@code m} yields 0)
     * @param m the modulus
     * @return a value in {@code [0, m)}
     */
    static long modinv(long a, int m) {
        long b = m;
        long u = 1;
        long v = 0;
        while (b > 0) {
            long t = a / b;
            a -= t * b;
            long tmp = a;
            a = b;
            b = tmp;
            u -= t * v;
            tmp = u;
            u = v;
            v = tmp;
        }
        u %= m;
        if (u < 0) {
            u += m;
        }
        return u;
    }
}