"""For each test case (N, M) print C(N, 2) * C(M, 2) * M**(N - 1) mod 998244353.

This is the same quantity the original one-liner computed as
``N*M*M*(N-1)*(M-1)//4 * pow(M, N-2, mod) % mod``: the product
``N*(N-1)*M*(M-1)`` is always divisible by 4, and ``M * M**(N-2)`` equals
``M**(N-1)``. Writing it with a non-negative exponent avoids ``pow`` with a
negative exponent, which fails when N < 2 and M has no inverse mod MOD (or on
Python < 3.8) even though the answer is 0 there.
"""

MOD = 998244353
m = MOD  # short alias kept from the original script


def solve(n: int, m: int, mod: int = MOD) -> int:
    """Return C(n, 2) * C(m, 2) * m**(n - 1) modulo *mod*.

    Args:
        n: first integer read for the test case (N).
        m: second integer read for the test case (M).
        mod: modulus; defaults to 998244353.

    Returns:
        The value in ``range(mod)``. For n < 2 the factor C(n, 2) is 0, so the
        result is 0; the exponent is clamped so ``pow`` never sees a negative
        power.
    """
    pairs_n = n * (n - 1) // 2  # exact: n*(n-1) is always even
    pairs_m = m * (m - 1) // 2  # exact: m*(m-1) is always even
    power = pow(m, max(n - 1, 0), mod)
    return pairs_n % mod * (pairs_m % mod) % mod * power % mod


def main() -> None:
    """Read T, then T lines of "N M", printing one answer per line."""
    for _ in range(int(input())):
        n, k = map(int, input().split())
        print(solve(n, k))


if __name__ == "__main__":
    main()