mod = 998244353 T = int(input()) for i in range(T): N,K = map(int,input().split()) print(pow(2,K*(N-1),mod)*(pow(2,K,mod)-1)*N % mod)