mod = 998244353


def count_ways(n, m):
    """Return the answer for one ``(N, M)`` test case, reduced modulo ``mod``.

    Runs a 3-state linear recurrence for ``m`` steps starting from state
    ``(1, 0, 0)`` and returns the sum of the final state.  Transitions:

        s0' = s0 + s1 + s2
        s1' = n*s0 + (n-1)*s1 + (n-2)*s2
        s2' = a*s0 + b*s1 + c*s2

    where ``a = n*(n-1)//2 - n``, ``b = a - n + 3``, ``c = a - 2*n + 7``
    (all mod ``mod``), and ``a = b = c = 0`` when ``n <= 3``.
    NOTE(review): ``a`` equals C(n, 2) - n, presumably the number of
    non-adjacent pairs on an n-cycle; confirm against the problem statement.

    Args:
        n: the ``N`` value of the test case.
        m: number of DP steps (``M``); must be non-negative.

    Returns:
        The sum of the three states after ``m`` steps, modulo ``mod``.

    Raises:
        ValueError: if ``m`` is negative.
    """
    if m < 0:
        raise ValueError(f"m must be non-negative, got {m}")

    a = (n * (n - 1) // 2 - n) % mod
    b = (a - n + 3) % mod
    c = (a - 2 * n + 7) % mod
    # For n <= 3 the third state never receives any weight.
    if n <= 3:
        a = b = c = 0

    # Only the previous row is ever read, so keep rolling state instead of
    # an (m+1)-row table.
    s0, s1, s2 = 1, 0, 0
    for _ in range(m):
        s0, s1, s2 = (
            (s0 + s1 + s2) % mod,
            (s0 * n + s1 * (n - 1) + s2 * (n - 2)) % mod,
            (s0 * a + s1 * b + s2 * c) % mod,
        )
    return (s0 + s1 + s2) % mod


def solve():
    """Read one ``N M`` line from stdin and print ``count_ways(N, M)``."""
    N, M = map(int, input().split())
    print(count_ways(N, M))


def main():
    """Read the test-case count ``t``, then solve ``t`` test cases."""
    t = int(input())
    for _ in range(t):
        solve()


# Guard the driver so importing this module does not block on stdin.
if __name__ == "__main__":
    main()