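MOD = 998244353  # prime modulus

def main():
    n, m = map(int, input().split())
    # 2^(n*m - 1) mod MOD; three-argument pow does fast modular exponentiation
    return pow(2, n * m - 1, MOD)

if __name__ == "__main__":
    print(main())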