import sys input = sys.stdin.readline mod=998244353 t=int(input()) for tests in range(t): N,M=list(map(int,input().split())) ALL=pow(pow(2,N,mod)-1,M,mod) #print(ALL) x=pow(2,M,mod)-1 minus=pow(1+x,N,mod) ANS=ALL-minus+pow(x,N,mod) print(ANS%mod)