R=input MOD=998244353 I=pow(2,-2,MOD) for _ in[0]*int(R()):N,M=map(int,R().split());print(pow(M,N-1,MOD)*M*(M-1)*N*(N-1)*I%MOD)