t = int(input()) mod = 998244353 for _ in range(t): n, k = map(int, input().split()) print((n * pow(2, n * k, mod) * (1 - pow(pow(2, k, mod), mod - 2, mod))) % mod)