import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;

/**
 * Reads a test-case count {@code t} followed by {@code t} lines of the form {@code n m}, and
 * prints {@link #count(int, int)} for each case on its own line.
 */
public class Main {

    /** Modulus applied to every intermediate value and to each printed answer. */
    private static final long MOD = 998_244_353L;

    public static void main(String[] args) throws IOException {
        try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
                PrintWriter pw = new PrintWriter(System.out)) {
            int t = Integer.parseInt(br.readLine().trim());
            for (int z = 0; z < t; z++) {
                // Tolerate repeated / trailing whitespace between the two numbers.
                String[] parts = br.readLine().trim().split("\\s+");
                int n = Integer.parseInt(parts[0]);
                int m = Integer.parseInt(parts[1]);
                pw.println(count(n, m));
            }
        }
    }

    /**
     * Evaluates a three-state linear recurrence for {@code m} steps and returns the sum of the
     * three state values after the last step, modulo 998244353.
     *
     * <p>Writing {@code k = n - 3} and {@code s2 = k*(n-2)/2 + k} (with {@code s2 = 0} when
     * {@code n < 4}), the state values at step 0 are {@code (1, n, s2)} and each later step is
     * <pre>
     *   a' = a + b + c
     *   b' = n*a + (n-1)*b + (n-2)*c
     *   c' = s2*a + (s2 - k)*b + (s2 - 2k + 1)*c      (c' stays 0 when n < 4)
     * </pre>
     * NOTE(review): the combinatorial meaning of the three states is not stated in this file.
     *
     * @param n the per-case parameter; expected to be non-negative (TODO confirm constraints)
     * @param m number of steps, counting step 0; must be at least 1
     * @return {@code (a + b + c) mod 998244353} at step {@code m - 1}
     * @throws IllegalArgumentException if {@code m < 1}
     */
    static long count(int n, int m) {
        if (m < 1) {
            throw new IllegalArgumentException("m must be >= 1, got " + m);
        }

        // Exact seed coefficients. For any int n, k*(n-2) < 2^62, so these fit in a long.
        // For n >= 4 both derived coefficients are positive: s2 - k = k(k+1)/2 and
        // s2 - 2k + 1 = k(k-1)/2 + 1. For n < 4 all three stay 0, which keeps the third
        // state at 0 exactly as intended.
        long seed2 = 0;
        long coef21 = 0;
        long coef22 = 0;
        if (n >= 4) {
            long k = n - 3L;
            seed2 = k * (n - 2L) / 2 + k;
            coef21 = seed2 - k;
            coef22 = seed2 - 2 * k + 1;
        }

        // Reduce before multiplying. seed2 grows like n^2/2, and multiplying it by a value
        // below MOD (~2^30) would overflow a long once n exceeds roughly 1.3e5.
        long s2 = seed2 % MOD;
        long c21 = coef21 % MOD;
        long c22 = coef22 % MOD;

        long a = 1;
        long b = n % MOD;
        long c = s2;
        for (int i = 1; i < m; i++) {
            // Every factor on the left is < MOD and every |multiplier| is < 2^31,
            // so each product stays below 2^61.
            long na = (a + b + c) % MOD;
            long nb = (a * n % MOD + b * (n - 1L) % MOD + c * (n - 2L) % MOD) % MOD;
            long nc = (a * s2 % MOD + b * c21 % MOD + c * c22 % MOD) % MOD;
            a = na;
            b = nb;
            c = nc;
        }
        return (a + b + c) % MOD;
    }
}