mod = 998244353 t = int(input()) for _ in range(t): n, k = map(int, input().split()) ans = pow(pow(2, n, mod), k, mod) - pow(pow(2, n-1, mod), k, mod) ans = ans * n % mod print(ans)