import sys

MOD = 998244353


def count_ways(n, m, mod=MOD):
    """Return the answer for one test case ``(n, m)`` modulo *mod*.

    Runs a three-state linear recurrence for ``m - 1`` steps, starting from
    ``(dp0, dp1, dp2) = (1, n, c20)``, and returns ``dp0 + dp1 + dp2``.
    One step maps the state as follows:

        dp0' = dp0 + dp1 + dp2
        dp1' = n*dp0 + (n-1)*dp1 + (n-2)*dp2
        dp2' = c20*dp0 + c21*dp1 + c22*dp2

    The coefficients ``c2*`` are defined below.

    NOTE(review): the coefficients match this counting problem, which should
    be confirmed against the problem statement:
      * ``n`` cells are arranged in a cycle;
      * each of ``m`` rows picks 0, 1 or 2 cells;
      * two picked cells in one row are never cyclically adjacent;
      * consecutive rows pick disjoint cells.
    Under that reading, state ``k`` means "the current row picked ``k`` cells".

    Args:
        n: number of cells per row.
        m: number of rows (the loop runs ``m - 1`` times).
        mod: modulus applied after each step and to the result.

    Returns:
        ``(dp0 + dp1 + dp2) % mod`` after the final step.
    """
    # These depend only on n, so compute them once rather than every step.
    # The (n > k) factors zero them out for small n, where the raw
    # expressions would be negative or meaningless.
    c20 = (n > 2) * (n * (n - 1) // 2 - n)
    c21 = (n > 2) * ((n - 1) * (n - 2) // 2 - (n - 2))
    c22 = (n > 3) * ((n - 2) * (n - 3) // 2 - (n - 4))

    dp0, dp1, dp2 = 1, n, c20
    for _ in range(m - 1):
        # The right-hand side is evaluated in full before assignment,
        # so every term uses the previous step's values.
        dp0, dp1, dp2 = (
            (dp0 + dp1 + dp2) % mod,
            (n * dp0 + (n - 1) * dp1 + (n - 2) * dp2) % mod,
            (c20 * dp0 + c21 * dp1 + c22 * dp2) % mod,
        )
    return (dp0 + dp1 + dp2) % mod


def main():
    """Read ``t`` and then ``t`` pairs ``n m`` from stdin; print one answer per line.

    Input is read as whitespace-separated tokens in a single buffered read.
    """
    data = sys.stdin.read().split()
    t = int(data[0])
    answers = []
    for i in range(t):
        n = int(data[1 + 2 * i])
        m = int(data[2 + 2 * i])
        answers.append(str(count_ways(n, m)))
    if answers:
        sys.stdout.write("\n".join(answers) + "\n")


if __name__ == "__main__":
    main()