# Interesting!
mod = 998244353


def compute_answer(n, m):
    """Return (2**n - 1)**m + (2**m - 1)**n - (2**m)**n, reduced modulo ``mod``.

    Every power is evaluated with three-argument ``pow`` so intermediate
    values stay below ``mod`` even for very large exponents.

    NOTE(review): the counting problem this closed form answers is not
    stated in the file; only the arithmetic is documented here.

    Args:
        n: First integer of the test case.
        m: Second integer of the test case.

    Returns:
        The value of the formula as an integer in ``[0, mod)``.
    """
    two_pow_n = pow(2, n, mod)
    two_pow_m = pow(2, m, mod)  # reused by the second and third terms
    total = (
        pow(two_pow_n - 1, m, mod)
        + pow(two_pow_m - 1, n, mod)
        - pow(two_pow_m, n, mod)
    )
    # The subtraction can make ``total`` negative; Python's ``%`` with a
    # positive modulus always yields a non-negative result.
    return total % mod


def solve():
    """Read one test case (``N M`` on a single line) from stdin and print its answer."""
    N, M = map(int, input().split())
    print(compute_answer(N, M))


def main():
    """Read the number of test cases from stdin, then solve each one in turn."""
    for _ in range(int(input())):
        solve()


# Guarded so that importing this module does not consume stdin.
if __name__ == "__main__":
    main()