"""Evaluate N * 2**(N*K) * (1 - 2**(-K)) modulo 998244353 per test case.

Input: a line holding T, followed by T test cases of two integers N and K.
Output: one residue per test case, each on its own line.
"""
import sys

MOD = 998244353


def solve(n, k, mod=MOD):
    """Return ``n * 2**(n*k) * (1 - 2**(-k))`` reduced modulo *mod*.

    ``2**(-k)`` means the modular inverse of ``2**k``, so *mod* must be odd
    (coprime with 2); otherwise the built-in ``pow`` raises ``ValueError``
    for ``k > 0``.
    """
    # Three-argument pow with a negative exponent (Python 3.8+) yields the
    # modular inverse raised to |k|, i.e. (2**-1)**k mod p -- no separate
    # inverse computation is needed.
    inv_pow2_k = pow(2, -k, mod)
    # Python's % is always non-negative, so this safely reduces 1 - inv < 0.
    factor = (1 - inv_pow2_k) % mod
    # (2**n)**k == 2**(n*k); Python ints handle the large exponent exactly.
    return n % mod * pow(2, n * k, mod) % mod * factor % mod


def main(inp=None, out=None):
    """Read ``T`` then ``T`` pairs ``N K`` and write one answer per line.

    *inp* / *out* default to ``sys.stdin`` / ``sys.stdout``; they are
    parameters only so the routine can be driven from in-memory streams.
    """
    inp = sys.stdin if inp is None else inp
    out = sys.stdout if out is None else out

    values = list(map(int, inp.read().split()))
    if not values:
        # Nothing to process; the old script crashed on int('') here.
        return
    t = values[0]
    # After T, values alternate N, K, N, K, ...; take exactly T pairs.
    ns = values[1:1 + 2 * t:2]
    ks = values[2:2 + 2 * t:2]
    out.write("".join(f"{solve(n, k)}\n" for n, k in zip(ns, ks)))


if __name__ == "__main__":
    main()