mod=998244353 for _ in range(int(input())): N,K=map(int,input().split()) res=N*(pow(2,N*K,mod)-pow(2,(N-1)*K,mod)) res%=mod print(res)