"""Answer T queries "M N" with a three-state linear recurrence mod 998244353.

Input (stdin): an integer T, followed by T pairs of integers ``M N``.
Output (stdout): one line per query, the sum of the three state values
after N steps, reduced modulo 998244353.

NOTE(review): for M >= 3 the weight k02 equals M*(M-3)/2, which matches the
number of ways to pick two non-adjacent cells on a cycle of length M.
Presumably state j means "j cells picked in the current row of a cyclic
N x M grid" -- confirm against the problem statement.
"""
import sys

MOD = 998244353


def transition_weights(m):
    """Return the exact weights ``(k02, k12, k22)`` of the moves into state 2.

    Closed forms of the sums the script used to accumulate with O(m) loops:

    * ``k12 = sum(max(0, m - i - 2) for i in range(1, m))``
    * ``k02 = max(0, m - 3) + k12`` (the extra ``i == 0`` term)
    * ``k22 = k12 - max(0, m - 4)`` (the same sum with the ``i == 2`` term
      left out)

    All three are 0 for ``m <= 3``.
    """
    # The non-zero terms of k12 are 1, 2, ..., m - 3: an arithmetic series.
    span = max(0, m - 3)
    k12 = span * (span + 1) // 2
    k02 = span + k12
    k22 = k12 - max(0, m - 4)
    return k02, k12, k22


def solve(m, n):
    """Return the answer for one query ``(m, n)`` modulo ``MOD``.

    Starting from state vector ``(1, 0, 0)``, apply ``n`` times:

    * state 0 <- s0 + s1 + s2
    * state 1 <- s0 * m + s1 * max(0, m - 1) + s2 * max(0, m - 2)
    * state 2 <- s0 * k02 + s1 * k12 + s2 * k22

    and return the sum of the three components.
    """
    # Reduce every weight once so the loop multiplies small residues only.
    k02, k12, k22 = (weight % MOD for weight in transition_weights(m))
    w01 = m % MOD
    w11 = max(0, m - 1) % MOD
    w21 = max(0, m - 2) % MOD

    s0, s1, s2 = 1, 0, 0
    for _ in range(n):
        # Tuple assignment evaluates the right side from the old state first.
        s0, s1, s2 = (
            (s0 + s1 + s2) % MOD,
            (s0 * w01 + s1 * w11 + s2 * w21) % MOD,
            (s0 * k02 + s1 * k12 + s2 * k22) % MOD,
        )
    return (s0 + s1 + s2) % MOD


def main():
    """Read all queries from stdin and print one answer per line."""
    tokens = sys.stdin.read().split()
    if not tokens:
        # Empty input: nothing to answer.
        return

    query_count = int(tokens[0])
    answers = []
    for q in range(query_count):
        m = int(tokens[1 + 2 * q])
        n = int(tokens[2 + 2 * q])
        answers.append(solve(m, n))

    # Single write; emits nothing (not even a blank line) when T == 0.
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()