import sys

MOD = 998244353


def solve(n, m, mod=MOD):
    """Return the per-test-case answer for (n, m) modulo *mod*.

    Computes ``C(n, 2) * m**(n - 2) * (m*(m-1)*m - m*(m*(m-1)//2))`` mod *mod*,
    reducing every factor modulo *mod* so intermediates stay small.

    Args:
        n: first value of the test case (assumed to be a non-negative count).
        m: second value of the test case.
        mod: modulus; defaults to 998244353.

    Returns:
        The answer as an int in ``[0, mod)``.
    """
    if n < 2:
        # C(n, 2) == 0, so the product is 0. Returning early also avoids
        # pow() with a negative exponent, which raises ValueError when m is
        # not invertible modulo *mod* (and on Python < 3.8 always).
        return 0
    pairs = n * (n - 1) // 2 % mod
    # m*(m-1) is always even, so the floor division is exact.
    tail = (m * (m - 1) * m - m * (m * (m - 1) // 2)) % mod
    return pairs * pow(m, n - 2, mod) % mod * tail % mod


def main():
    """Read T test cases of "N M" from stdin and print one answer per line."""
    data = sys.stdin.buffer.read().split()
    t = int(data[0])
    answers = [
        solve(int(data[1 + 2 * i]), int(data[2 + 2 * i])) for i in range(t)
    ]
    # Batch the output into one write; like the original, T == 0 prints nothing.
    sys.stdout.write("".join(f"{a}\n" for a in answers))


if __name__ == "__main__":
    main()