#yuki1206 t=int(input()) mod=998244353 for i in range(t): n,k=map(int,input().split()) print((n*pow(2,n*k,mod)*(1-pow(pow(2,k,mod),mod-2,mod)))%mod)