"""Count n x m 0/1 grids whose corner-rectangle sums always differ in parity.

A grid is valid when, for every split point (i, j) with 0 <= i < n-1 and
0 <= j < m-1, the sum of the top-left rectangle (0, 0)-(i, j) and the sum of
the bottom-right rectangle (i+1, j+1)-(n-1, m-1) have different parities.
The program reads "n m" from stdin and prints the count modulo 998244353.

Brute-force counts (produced by ``brute_force``) that the closed form in
``solve`` reproduces:

    n=1: m=1..5 -> 2, 4, 8, 16, 32          (x2 per extra column)
    n=2: m=2..5 -> 8, 16, 32, 64            (x2 per extra column)
    n=3: m=3..5 -> 64, 256, 1024            (x4 per extra column)
    n=4: m=4..5 -> 2048, 16384              (x8 per extra column)
"""

MOD = 998244353


def _is_valid(grid, n, m):
    """Return True if every split point of the n x m ``grid`` has differing parities."""
    for i in range(n - 1):
        for j in range(m - 1):
            # Sum over the rectangle (0, 0) -> (i, j).
            top_left = sum(grid[y][x] for y in range(i + 1) for x in range(j + 1))
            # Sum over the rectangle (i+1, j+1) -> (n-1, m-1).
            bottom_right = sum(
                grid[y][x] for y in range(i + 1, n) for x in range(j + 1, m)
            )
            if top_left % 2 == bottom_right % 2:
                return False
    return True


def brute_force(n: int, m: int) -> int:
    """Count valid n x m grids by trying all 2**(n*m) fillings.

    Exponential in n*m; intended only for tiny sizes to cross-check ``solve``.
    The result is not reduced modulo MOD.
    """
    count = 0
    for bits in range(1 << (n * m)):
        # Bit (i*m + j) of ``bits`` is the value of cell (i, j).
        grid = [[(bits >> (i * m + j)) & 1 for j in range(m)] for i in range(n)]
        if _is_valid(grid, n, m):
            count += 1
    return count


def solve(n: int, m: int, mod: int = MOD) -> int:
    """Return the number of valid n x m grids modulo ``mod``.

    Args:
        n: number of rows (non-negative).
        m: number of columns (non-negative).
        mod: modulus for the result (defaults to 998244353).

    Raises:
        ValueError: if ``n`` or ``m`` is negative.
    """
    if n < 0 or m < 0:
        raise ValueError("grid dimensions must be non-negative")
    # The count depends on the shape only up to transposition, so normalise
    # to small <= large; this keeps both "1 m" and "m 1" correct.
    small, large = min(n, m), max(n, m)
    if small <= 1:
        # No split point exists (range(n-1) or range(m-1) is empty), so every
        # one of the 2**(n*m) grids is valid.
        return pow(2, small * large, mod)
    # Equivalent to 2**m * 2 * (2**(m-1))**(n-2) with n <= m:
    # m + 1 + (m-1)(n-2) == n*m - n - m + 3, which is symmetric in n and m.
    return pow(2, n * m - n - m + 3, mod)


def main() -> None:
    """Read "n m" from stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()