"""Batch evaluator for ``N * (2**K - 1) * 2**(K * (N - 1))`` modulo 998244353.

Input: the first token is the number of test cases T, followed by T pairs
``N K``. Output: one result per line, in input order.
"""
import sys

MOD = 998244353


def solve(n, k, mod=MOD):
    """Return ``n * (2**k - 1) * 2**(k * (n - 1)) % mod``.

    NOTE(review): this is presumably the closed-form answer to the source
    problem's counting question -- TODO confirm against the problem statement.

    Args:
        n: Non-negative integer N from the test case.
        k: Non-negative integer K from the test case.
        mod: Modulus for the result (defaults to 998244353).

    Returns:
        The value of the expression reduced into ``range(mod)``.
    """
    if n == 0:
        # The leading factor n makes the product 0. Returning early also
        # avoids a negative exponent in pow(), which needs Python 3.8+.
        return 0
    # pow() with three arguments does modular exponentiation in O(log exp),
    # so even a huge exponent k * (n - 1) is cheap. pow(2, k, mod) - 1 can be
    # negative only if pow returned 0, which cannot happen for an odd
    # modulus; the final % mod normalizes the sign regardless.
    return n * (pow(2, k, mod) - 1) * pow(2, k * (n - 1), mod) % mod


def main():
    """Read T test cases of ``N K`` from stdin; write one answer per line."""
    # Read all input at once and tokenize it; this is much faster than
    # calling input() once per test case when T is large.
    tokens = sys.stdin.buffer.read().split()
    t = int(tokens[0])
    answers = []
    for i in range(t):
        n = int(tokens[1 + 2 * i])
        k = int(tokens[2 + 2 * i])
        answers.append(solve(n, k))
    # Emit everything in one write; T == 0 produces no output, matching the
    # per-case print() behaviour of the original loop.
    sys.stdout.write("".join(f"{a}\n" for a in answers))


if __name__ == "__main__":
    main()