"""Answer ``t`` queries of ``n * ((2**n)**k - (2**(n-1))**k)`` modulo 998244353.

Input (stdin): the first token is ``t``; it is followed by ``t`` pairs of
integers ``n`` and ``k`` (one pair per line in the expected format).
One answer is printed per query, each on its own line.
"""

import sys

mod = 998244353


def power(a, n, mod):
    """Return ``a**n % mod`` for a non-negative exponent ``n``.

    Delegates to the built-in three-argument ``pow``, which performs
    binary (square-and-multiply) exponentiation in C.

    Raises:
        ValueError: if ``n`` is negative. A hand-rolled bit loop over
            ``format(n, "b")`` would treat the leading ``'-'`` as a bit and
            return a meaningless value, so this is rejected explicitly.
    """
    if n < 0:
        raise ValueError("exponent must be non-negative")
    return pow(a, n, mod)


def solve(n, k, mod=mod):
    """Return ``n * ((2**n)**k - (2**(n-1))**k) % mod`` for one query.

    For ``n == 0`` the leading factor ``n`` makes the answer 0, so the
    ``2**(n-1)`` term (which would need a negative exponent) is skipped.
    """
    if n == 0:
        return 0
    full = power(power(2, n, mod), k, mod)
    reduced = power(power(2, n - 1, mod), k, mod)
    # Python's % always returns a value in [0, mod) for positive mod, so a
    # negative difference is folded back into range here.
    return (full - reduced) * n % mod


def main():
    """Read every query from stdin and print one answer per line."""
    # Reading all tokens at once avoids per-line input() overhead for large t.
    data = sys.stdin.buffer.read().split()
    t = int(data[0])
    answers = []
    for i in range(t):
        n = int(data[1 + 2 * i])
        k = int(data[2 + 2 * i])
        answers.append(solve(n, k))
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()