import sys

MOD = 998244353
# Modular inverse of 2: MOD is odd, so (MOD + 1) / 2 is an integer and
# 2 * INV2 == MOD + 1 == 1 (mod MOD).
INV2 = (MOD + 1) >> 1


def solve(n, m):
    """Return ``m**n * (m - 1) / 2 * n*(n - 1)/2`` modulo ``MOD``.

    The division by 2 in ``(m - 1) / 2`` is done with the modular inverse
    ``INV2``. The division in ``n*(n - 1)/2`` is exact integer division,
    because ``n*(n - 1)`` is always even.

    For ``n`` equal to 0 or 1 the pair count ``n*(n - 1)//2`` is 0, so the
    result is 0. No special case is needed for these values.

    Args:
        n: First integer read for the test case (used as exponent and pair count).
        m: Second integer read for the test case (used as base).

    Returns:
        The value described above, in ``range(MOD)``.
    """
    pairs = n * (n - 1) // 2
    half_m_minus_1 = (m - 1) * INV2 % MOD
    return pow(m, n, MOD) * half_m_minus_1 % MOD * pairs % MOD


def main():
    """Read ``t`` then ``t`` pairs ``n m`` from stdin; print one answer per line."""
    data = sys.stdin.read().split()
    if not data:
        return
    t = int(data[0])
    answers = []
    for k in range(t):
        n = int(data[1 + 2 * k])
        m = int(data[2 + 2 * k])
        answers.append(solve(n, m))
    # One batched write instead of a print() per test case.
    sys.stdout.write("".join(f"{a}\n" for a in answers))


if __name__ == "__main__":
    main()