import sys
from collections import defaultdict

# Fast line reader returning raw bytes; int() accepts bytes, so getlist() can
# parse its fields directly. Intentionally shadows the builtin input().
input = sys.stdin.buffer.readline
sys.setrecursionlimit(10**7)

con = 998244353  # modulus applied to every answer
INF = float("inf")


def getlist():
    """Read one line from stdin and return its whitespace-separated fields as ints."""
    return list(map(int, input().split()))


def solve(N, K, mod=con):
    """Return ``N * (2**(N*K) - 2**((N-1)*K))`` reduced modulo ``mod``.

    Args:
        N: first integer of a test case.
        K: second integer of a test case.
        mod: modulus (defaults to ``con`` = 998244353).

    Returns:
        An int in ``[0, mod)``.
    """
    diff = pow(2, N * K, mod) - pow(2, (N - 1) * K, mod)
    # diff may be negative after the two independent reductions; Python's %
    # always yields a result in [0, mod), so no extra correction is needed.
    return diff * N % mod


def main():
    """Read T, then T lines of "N K", and print solve(N, K) for each, one per line."""
    T = int(input())
    answers = []
    for _ in range(T):
        N, K = getlist()
        answers.append(solve(N, K))
    # A single write instead of T separate print() calls; much cheaper when T is large.
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == '__main__':
    main()