t=int(input()) mod=998244353 for i in range(t): n,k=map(int,input().split()) ans=pow(2,n*k,mod) ans-=pow(2,(n-1)*k,mod) ans%=mod print((ans*n)%mod)