"""Count proper K-colourings of a graph with N vertices and N edges, mod 998244353.

Input (stdin): ``N K`` followed by N edges ``u v`` (vertices are 1-indexed).
The graph is assumed to be connected, so with N edges it contains exactly one
cycle; this is NOT validated.

Stripping the pendant trees leaves a cycle of length ``l``. A cycle has
``(K-1)^l + (-1)^l (K-1)`` proper colourings, and each of the ``N - l`` tree
vertices multiplies that by ``K-1`` (any colour except its parent's).
"""
import sys
from collections import deque

MOD = 998244353


def count_colorings(n, k, edges, mod=MOD):
    """Return the number of proper k-colourings of the graph, modulo *mod*.

    Args:
        n: number of vertices, labelled 1..n.
        k: number of available colours.
        edges: iterable of ``(u, v)`` pairs. Parallel edges and self-loops are
            counted with multiplicity, so a doubled edge forms a 2-cycle and a
            self-loop forms a 1-cycle (which correctly yields 0 colourings).
        mod: modulus for the result.

    Returns:
        ``(k-1)^(n-l) * ((k-1)^l + (-1)^l (k-1)) mod mod``, where ``l`` is the
        number of vertices left after repeatedly removing degree-1 vertices.
    """
    degree = [0] * (n + 1)
    adjacency = [[] for _ in range(n + 1)]
    for u, v in edges:
        degree[u] += 1
        degree[v] += 1
        adjacency[u].append(v)
        adjacency[v].append(u)

    # Peel off pendant trees. Degrees keep edge multiplicity, so parallel
    # edges are not merged the way a set-based adjacency would merge them.
    removed = [False] * (n + 1)
    leaves = deque(x for x in range(1, n + 1) if degree[x] == 1)
    cycle_len = n
    while leaves:
        leaf = leaves.popleft()
        # A queued vertex can drop to degree 0 only on tree-shaped (invalid)
        # input; keep it so cycle_len never falls below 1.
        if degree[leaf] != 1:
            continue
        removed[leaf] = True
        cycle_len -= 1
        for nb in adjacency[leaf]:
            if not removed[nb]:
                degree[nb] -= 1
                if degree[nb] == 1:
                    leaves.append(nb)

    base = k - 1
    sign = -1 if cycle_len % 2 else 1  # (-1)^l
    # (k-1)^(n-l+1) * ((k-1)^(l-1) + (-1)^l)
    #   == (k-1)^(n-l) * ((k-1)^l + (-1)^l (k-1)); both exponents are >= 0.
    return pow(base, n - cycle_len + 1, mod) * (pow(base, cycle_len - 1, mod) + sign) % mod


def main():
    """Read the graph from stdin and print the colouring count."""
    data = sys.stdin.buffer.read().split()
    n, k = int(data[0]), int(data[1])
    nums = list(map(int, data[2:2 + 2 * n]))
    edges = zip(nums[0::2], nums[1::2])
    print(count_colorings(n, k, edges))


if __name__ == "__main__":
    main()