mod=998244353 t=int(input()) for _ in range(t): n,m=map(int,input().split()) tar=0 tar+=pow(pow(2,n,mod)-1,m,mod) tar-=pow(pow(2,m,mod),n,mod) tar+=pow(pow(2,m,mod)-1,n,mod) print(tar%mod)