import sys

MOD = 998244353


def solve(n, m, mod=MOD):
    """Run the three-state linear recurrence for ``m`` steps; return its state sum modulo ``mod``.

    Starting from ``(s0, s1, s2) = (1, 0, 0)``, each step applies::

        s0' = s0 + s1 + s2
        s1' = n*s0 + (n-1)*s1 + (n-2)*s2
        s2' = a*s0 + b*s1 + c*s2        (only when n >= 4; otherwise s2 stays 0)

    with ``a = n(n-1)/2 - n``, ``b = (n-1)(n-2)/2 - n + 2`` and
    ``c = (n-2)(n-3)/2 - n + 4``. ``a`` equals the diagonal count of a convex
    n-gon; presumably the underlying problem counts side/diagonal choices —
    TODO confirm against the problem statement.

    Only the previous step's three values are kept, so memory is O(1)
    rather than the O(m) of a full DP table.

    Args:
        n: Recurrence parameter (the ``n`` read from input).
        m: Number of steps; must be non-negative.
        mod: Modulus applied after every step (default 998244353).

    Returns:
        ``(s0 + s1 + s2) % mod`` after ``m`` steps.

    Raises:
        ValueError: If ``m`` is negative.
    """
    if m < 0:
        raise ValueError(f"m must be non-negative, got {m}")
    # The state-2 coefficients do not change between steps, so compute them
    # once. For n <= 2 the first one would be negative, which is why the
    # transition into state 2 is disabled below n = 4 (for n = 3 they are
    # all zero anyway).
    if n >= 4:
        a = n * (n - 1) // 2 - n
        b = (n - 1) * (n - 2) // 2 - n + 2
        c = (n - 2) * (n - 3) // 2 - n + 4
    else:
        a = b = c = 0

    s0, s1, s2 = 1, 0, 0
    for _ in range(m):
        # Tuple assignment evaluates every right-hand side from the old
        # values, exactly like reading dp[j] to build dp[j + 1].
        s0, s1, s2 = (
            (s0 + s1 + s2) % mod,
            (n * s0 + (n - 1) * s1 + (n - 2) * s2) % mod,
            (a * s0 + b * s1 + c * s2) % mod,
        )
    return (s0 + s1 + s2) % mod


def main():
    """Read ``t`` followed by ``t`` pairs ``n m`` from stdin; print one answer per pair."""
    data = sys.stdin.read().split()
    t = int(data[0])
    out = []
    pos = 1
    for _ in range(t):
        n, m = int(data[pos]), int(data[pos + 1])
        pos += 2
        out.append(f"{solve(n, m)}\n")
    # A single write instead of one print() per test case.
    sys.stdout.write("".join(out))


if __name__ == "__main__":
    main()