T = int(input()) m = 998244353 for _ in range(T) : a, b = map(int, input().split()) print((pow(pow(2, a, m)-1, b, m) - (pow(2, a*b, m) - pow(pow(2, b, m) - 1, a, m))) % m)