mod = 998244353 t = int(input()) for _ in range(t): n, m = list(map(int, input().split())) print((pow(pow(2, n, mod)-1, m, mod) - (pow(2, n*m, mod)-pow(pow(2, m, mod)-1, n, mod)))%mod)