t = int(input()) mod = 998244353 for _ in range(t): n, k = map(int, input().split()) ans = pow(pow(2, n, mod), k, mod) - pow(pow(2, n - 1, mod), k, mod) ans = ans * n % mod print(ans)