"""Read two integers H and W from stdin and print a count modulo 998244353.

For H == 1 the answer is 1.  Otherwise the answer is

    sum over k >= 1 of
        2 ** (H*W - (H + W - 1) - (k - 1))
        * C(W - 1, k // 2) * C(H - 2, (k - 1) // 2)

where the sum stops at the first k whose exponent would be negative or whose
binomial product is zero.
NOTE(review): the combinatorial meaning of k and of the H + W - 1 term is not
stated in this file; confirm against the original problem statement.
"""

MOD = 998244353


def _factorial_tables(size):
    """Return (fact, inv_fact) lists of length `size`, reduced modulo MOD.

    fact[i] is i! mod MOD and inv_fact[i] is its modular inverse.  Assumes
    size - 1 < MOD so that every factorial is invertible.
    """
    fact = [1] * size
    for i in range(1, size):
        fact[i] = fact[i - 1] * i % MOD

    # One Fermat inversion for the largest factorial, then walk downwards
    # using 1/(i-1)! = (1/i!) * i.  This replaces one pow() per entry.
    inv_fact = [1] * size
    inv_fact[-1] = pow(fact[-1], MOD - 2, MOD)
    for i in range(size - 1, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    return fact, inv_fact


def solve(height, width):
    """Return the answer for a grid of `height` rows and `width` columns.

    Args:
        height: the first input integer (H), expected to be >= 1.
        width: the second input integer (W), expected to be >= 1.

    Returns:
        The sum described in the module docstring, reduced modulo MOD.
    """
    if height == 1:
        return 1

    fact, inv_fact = _factorial_tables(max(height, width) + 1)

    def comb(n, r):
        """Binomial coefficient C(n, r) mod MOD; 0 when r is out of range."""
        if not 0 <= r <= n:
            return 0
        return fact[n] * inv_fact[r] % MOD * inv_fact[n - r] % MOD

    # The exponent of two for term k is max_k - k, so k may not exceed max_k.
    max_k = height * width - (height + width - 1) + 1

    total = 0
    for k in range(1, max_k + 1):
        ways = comb(width - 1, k // 2) * comb(height - 2, (k - 1) // 2) % MOD
        if ways == 0:
            # Both lower indices only grow with k, so every later term is
            # zero as well.
            break
        total = (total + pow(2, max_k - k, MOD) * ways) % MOD
    return total


def main():
    """Read "H W" from standard input and print the answer."""
    height, width = map(int, input().split())
    print(solve(height, width))


if __name__ == "__main__":
    main()