MOD = 998244353


def solve(h: int, w: int, mod: int = MOD) -> int:
    """Return ``2 ** rep % mod`` for an ``h`` x ``w`` input.

    ``rep`` is ``3 + (h-2) + (w-2) + (h-2)*(w-2)``, which is the same formula
    the original script used. Algebraically this equals ``(h-1)*(w-1) + 2``,
    so for any ``h, w >= 1`` the exponent is at least 2 and never negative.

    NOTE(review): the combinatorial meaning of this count (which cells or
    choices contribute a factor of 2) is not visible from this file.
    Confirm it against the original problem statement.

    Args:
        h: First integer read from input (presumably grid height).
        w: Second integer read from input (presumably grid width).
        mod: Modulus for the result. Defaults to 998244353.

    Returns:
        ``pow(2, rep, mod)``.
    """
    rep = 3
    # (h-2) + (w-2) + (h-2)*(w-2) == (h-1)*(w-1) - 1
    rep += (h - 2) + (w - 2) + (h - 2) * (w - 2)
    # Three-argument pow does modular exponentiation without building 2**rep.
    return pow(2, rep, mod)


def main() -> int:
    """Read ``h w`` from one line of standard input and return the answer."""
    h, w = map(int, input().split())
    return solve(h, w)


# Guarded so that importing this module (e.g. from tests) does not block
# waiting on stdin.
if __name__ == "__main__":
    print(main())