# Modulus applied to every result printed by this script.
mod = 998244353


def doubling(n, m):
    """Return ``n`` raised to the power ``m``, reduced modulo ``mod``.

    Delegates to the built-in three-argument ``pow`` instead of a manual
    square-and-multiply loop.  For the non-negative exponents this script
    passes, the result is the same as the manual loop produced.

    For a negative ``m`` the built-in computes a modular inverse
    (Python 3.8+) and raises ``ValueError`` when no inverse exists; the
    previous manual loop never terminated in that case.
    """
    return pow(n, m, mod)


def main():
    """Read the test cases from stdin and print one answer per case.

    Input format, as read below: the first line holds the number of test
    cases; each following line holds two integers ``N`` and ``K``.
    """
    T = int(input())
    for _ in range(T):
        N, K = map(int, input().split())
        # `*` and `%` share precedence and associate left to right, so the
        # whole product is reduced; Python's `%` keeps the result in
        # [0, mod) even when the difference of the two powers is negative.
        print(N * (doubling(2, N * K) - doubling(2, K * (N - 1))) % mod)


# Guarded so that importing this module does not consume stdin.
if __name__ == "__main__":
    main()