"""Three-state counting DP, evaluated once per test case.

Input: a line with T, followed by T test cases, each giving two ints N and M.
Output: one line per test case with the DP total modulo 998244353.
"""

import sys

MOD = 998244353
mod = 998244353
MOD2 = 10**9 + 7
mod2 = 10**9 + 7
# memo : len([a,b,...,z])==26


def mi():
    """Read one stdin line and return a lazy map of its whitespace-separated ints."""
    return map(int, input().split())


def li():
    """Read one stdin line and return its whitespace-separated ints as a list."""
    return list(mi())


def solve(n, m, modulus=MOD):
    """Return the 3-state DP total after ``m`` steps for parameter ``n``.

    State j in {0, 1, 2} is the state reached at the previous step; the walk
    starts in state 0. Each step moves from state j to:
      * state 0 with weight 1,
      * state 1 with weight ``n - j``,
      * state 2 with weight ``C(n - j, 2) - (n - 2j)``, only when ``n >= 4``.

    Args:
        n: The N value of the test case.
        m: Number of steps (the M value); ``m == 0`` yields 1.
        modulus: Modulus applied after every step (defaults to ``MOD``).

    Returns:
        Sum of the three state counts after ``m`` steps, reduced mod ``modulus``.
    """
    # NOTE(review): presumably the j items used in the previous step are
    # unavailable in this one, hence the "n - j" terms -- confirm against the
    # problem statement.
    single = [n - j for j in range(3)]
    if n >= 4:
        # Unordered pairs from the n - j candidates, minus n - 2j disallowed pairs.
        pair = [(n - j) * (n - j - 1) // 2 - (n - 2 * j) for j in range(3)]
    else:
        # For small n the pair formula is not a valid count (n = 2, j = 0 gives
        # 1 - 2 = -1), so state 2 is never entered.
        pair = [0, 0, 0]

    # Only the previous row of the DP is ever read, so three rolling counters
    # replace the (m + 1) x 3 table.
    s0, s1, s2 = 1, 0, 0
    for _ in range(m):
        s0, s1, s2 = (
            (s0 + s1 + s2) % modulus,
            (s0 * single[0] + s1 * single[1] + s2 * single[2]) % modulus,
            (s0 * pair[0] + s1 * pair[1] + s2 * pair[2]) % modulus,
        )
    return (s0 + s1 + s2) % modulus


def main():
    """Read all test cases from stdin and write one answer per line to stdout."""
    data = sys.stdin.buffer.read().split()
    t = int(data[0])
    answers = []
    for k in range(t):
        n = int(data[1 + 2 * k])
        m = int(data[2 + 2 * k])
        answers.append(solve(n, m))
    # Batch output into one write; with T == 0 nothing is printed, as before.
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()