import sys

MOD = 998244353


def _pairs(k, edges):
    """Return ``C(k, 2) - edges`` clamped at 0, or 0 when ``k < 2``.

    NOTE(review): from the recurrence this presumably counts unordered pairs
    among ``k`` still-free rows that are not joined by one of ``edges``
    adjacencies -- inferred, not stated anywhere; confirm against the problem.
    """
    # Fewer than two rows cannot hold a pair. Without this guard the
    # "- edges" term turns positive when edges < 0 (N=1 gives k=0, edges=-1
    # -> 1), injecting impossible two-row states into the DP.
    if k < 2:
        return 0
    return max(0, k * (k - 1) // 2 - edges)


def solve(n, m, mod=MOD):
    """Return the DP total after ``m`` columns with parameter ``n``, modulo ``mod``.

    The state is a triple ``(dp0, dp1, dp2)``; judging by the transitions
    (``n``, ``n-1``, ``n-2`` choices for a single row) index ``j`` presumably
    means "j rows chosen in the last column" -- confirm against the problem.
    For ``m <= 1`` no transition is applied and the initial column is summed.
    """
    # Transition coefficients depend only on n: hoisted out of the loop.
    c0 = _pairs(n, n) % mod          # pairs available after an empty column
    c1 = _pairs(n - 1, n - 2) % mod  # pairs available after a one-row column
    c2 = _pairs(n - 2, n - 4) % mod  # pairs available after a two-row column
    s0, s1, s2 = 1 % mod, n % mod, c0

    for _ in range(m - 1):
        s0, s1, s2 = (
            (s0 + s1 + s2) % mod,
            (s0 * n + s1 * (n - 1) + s2 * (n - 2)) % mod,
            (s0 * c0 + s1 * c1 + s2 * c2) % mod,
        )
    return (s0 + s1 + s2) % mod


def main():
    """Read ``T`` followed by ``T`` pairs ``N M`` from stdin; print one answer per line."""
    tokens = sys.stdin.read().split()
    t = int(tokens[0])
    answers = []
    for i in range(t):
        n, m = int(tokens[2 * i + 1]), int(tokens[2 * i + 2])
        answers.append(solve(n, m))
    # One write; emits nothing at all when T == 0, like the original loop.
    sys.stdout.write("".join(f"{a}\n" for a in answers))


if __name__ == "__main__":
    main()