T = int(input()) MOD = 998244353 for _ in range(T): N, K = map(int, input().split()) print(N*(pow(2, K, MOD)-1)*pow(pow(2, N-1, MOD), K, MOD)%MOD)