"""Answer T queries of (N, K) modulo 998244353.

For each query the answer is (2^K - 1) * N * 2^(K*(N-1)) mod 998244353.
Input: the first token is T, followed by T pairs of integers N and K
(one pair per line in the expected format).
"""

import sys

# `test` is deliberately left out: it reads stdin, and its name would make
# test runners that star-import this module collect it as a unit test.
__all__ = ["MOD", "mod", "solve", "main"]

MOD = 998_244_353
mod = MOD  # backward-compatible lowercase alias for existing callers


def solve(n: int, k: int) -> int:
    """Return (2**k - 1) * n * 2**(k*(n-1)) modulo MOD.

    Pure helper: takes the query values directly instead of reading stdin.
    """
    if n == 0:
        # The factor n makes the result 0. Returning early also avoids
        # pow() with exponent -1, which raises ValueError before Python 3.8.
        return 0
    base = pow(2, k, MOD)  # 2^K mod p; never 0 because MOD is an odd prime
    m = base - 1
    # base^(n-1) == 2^(K*(n-1)) (mod p)
    return m * n * pow(base, n - 1, MOD) % MOD


def test() -> int:
    """Read one "N K" line from stdin and return its answer.

    Kept for compatibility with the original script. Despite its name,
    this is not a unit test.
    """
    n, k = map(int, input().split())
    return solve(n, k)


def main() -> None:
    """Read T followed by T (N, K) pairs from stdin and print one answer per line."""
    tokens = sys.stdin.read().split()
    if not tokens:
        return
    t = int(tokens[0])
    values = iter(map(int, tokens[1:1 + 2 * t]))
    # zip over the same iterator twice pairs consecutive tokens as (n, k).
    answers = [solve(n, k) for n, k in zip(values, values)]
    sys.stdout.write("".join(f"{a}\n" for a in answers))


if __name__ == "__main__":
    main()