mod = 998244353


def compute_answer(n: int, m: int, modulus: int = mod) -> int:
    """Return ``(2^n - 1)^m - (2^(n*m) - (2^m - 1)^n)`` reduced modulo *modulus*.

    Interpreted over n-by-m 0/1 matrices:

    * ``(2^n - 1)^m`` counts matrices with no all-zero column (each of the
      m columns is a non-zero n-bit vector);
    * ``(2^m - 1)^n`` counts matrices with no all-zero row;
    * ``2^(n*m)`` counts all matrices.

    By inclusion-exclusion the expression therefore equals
    (# matrices with no zero row and no zero column) minus
    (# matrices with at least one zero row and at least one zero column).

    Args:
        n: Number of rows (non-negative).
        m: Number of columns (non-negative).
        modulus: Positive modulus; defaults to the module constant ``mod``.

    Returns:
        The value reduced into ``[0, modulus)``.
    """
    total = pow(2, n * m, modulus)  # same as pow(pow(2, m), n) mod modulus
    no_zero_column = pow(pow(2, n, modulus) - 1, m, modulus)
    no_zero_row = pow(pow(2, m, modulus) - 1, n, modulus)
    # Python's % always yields a non-negative result for a positive modulus,
    # so the possibly-negative difference is normalized here.
    return (no_zero_column - (total - no_zero_row)) % modulus


def solve() -> None:
    """Read one ``n m`` line from stdin and print the answer for it."""
    n, m = map(int, input().split())
    print(compute_answer(n, m))


def main() -> None:
    """Read the number of test cases ``t``, then answer each case in turn."""
    t = int(input())
    for _ in range(t):
        solve()


# Guarded so importing this module does not block waiting on stdin.
if __name__ == "__main__":
    main()