MOD = 998244353


def solve(n: int, m: int) -> int:
    """Return the answer for one test case, modulo ``MOD``.

    Runs a linear recurrence over three running totals ``(s0, s1, s2)``.
    They start at ``(1, n, w0)`` and are advanced ``m - 1`` times; the result
    is ``(s0 + s1 + s2) % MOD``.  For ``m <= 1`` no step is taken and the
    initial totals are summed directly.

    NOTE(review): the combinatorial problem these totals count is not stated
    in this file; the coefficients below are transcribed verbatim from the
    original one-line script -- confirm against the problem statement.
    """
    # Coefficients feeding the third total.  The guards zero out the terms
    # for small ``n`` (where the raw formulas would go negative or are not
    # meant to apply).  Reducing them modulo MOD once keeps the intermediate
    # integers small and does not change any result, since only sums and
    # products reduced modulo MOD follow.
    w0 = (n * (n - 1) // 2 - n) % MOD if n > 2 else 0
    w1 = ((n - 1) * (n - 2) // 2 - (n - 2)) % MOD if n > 2 else 0
    w2 = ((n - 2) * (n - 3) // 2 - (n - 4)) % MOD if n > 3 else 0

    s0, s1, s2 = 1, n, w0
    for _step in range(m - 1):
        # Build all three new totals from the *old* ones before assigning.
        s0, s1, s2 = (
            (s0 + s1 + s2) % MOD,
            (n * s0 + (n - 1) * s1 + (n - 2) * s2) % MOD,
            (w0 * s0 + w1 * s1 + w2 * s2) % MOD,
        )
    return (s0 + s1 + s2) % MOD


def main() -> None:
    """Read test cases from stdin and print one answer per line.

    Input format: the first line holds the number of test cases; each test
    case is one line with two integers ``n`` and ``m``.
    """
    for _case in range(int(input())):
        n, m = map(int, input().split())
        print(solve(n, m))


if __name__ == "__main__":
    main()