"""Count proper vertex colourings of a graph given as N vertices and N edges.

Input (stdin): a line ``N K`` followed by N lines ``u v`` (1-indexed edge
endpoints).  Output (stdout): the number of colourings modulo 998244353.

NOTE(review): the counting formula presumably relies on the graph being
connected with exactly one cycle (which N vertices + N edges implies for a
connected simple graph); the code does not validate this -- confirm against
the problem statement.
"""

import sys
from collections import deque
from typing import Iterable, Tuple

MOD = 998244353


def _remaining_edge_count(n: int, edges: Iterable[Tuple[int, int]]) -> int:
    """Strip degree-1 vertices repeatedly and return the number of edges left.

    For a connected graph with a single cycle, what is left is exactly that
    cycle, so the returned value is the cycle length.
    """
    adjacency = [[] for _ in range(n)]
    degree = [0] * n
    for u, v in edges:
        adjacency[u].append(v)
        adjacency[v].append(u)
        degree[u] += 1
        degree[v] += 1

    # A vertex is marked as soon as it is queued so it is never queued twice
    # and never has its degree touched by a neighbour afterwards.
    queued = [False] * n
    queue = deque()
    for vertex in range(n):
        if degree[vertex] == 1:
            queued[vertex] = True
            queue.append(vertex)

    while queue:
        current = queue.popleft()
        # A queued vertex always has exactly one edge left; drop it.
        degree[current] -= 1
        for neighbour in adjacency[current]:
            if queued[neighbour]:
                continue
            degree[neighbour] -= 1
            if degree[neighbour] == 1:
                queued[neighbour] = True
                queue.append(neighbour)

    # Every surviving edge is counted once from each endpoint.
    return sum(degree) // 2


def count_colorings(n: int, k: int, edges: Iterable[Tuple[int, int]]) -> int:
    """Return the number of proper ``k``-colourings modulo ``MOD``.

    Args:
        n: Number of vertices (vertices are ``0 .. n-1``).
        k: Number of available colours.
        edges: Pairs of 0-indexed endpoints, one pair per edge.

    Returns:
        (ways to colour the cycle) * (k-1)^(vertices outside the cycle),
        reduced modulo ``MOD``.  If no edges survive the leaf stripping the
        result is 0.
    """
    cycle_len = _remaining_edge_count(n, edges)

    # Walk along the cycle with the first vertex's colour fixed:
    #   same = ways in which the current vertex has the first vertex's colour
    #   diff = ways in which it has any other colour
    # Only the latest pair is needed, so no history is kept.
    same, diff = 1, 0
    for _ in range(cycle_len - 1):
        same, diff = diff, (same * (k - 1) + diff * (k - 2)) % MOD

    # The last cycle vertex must differ from the first; the first vertex has
    # k possible colours.
    result = diff * k % MOD
    # Each vertex outside the cycle only has to avoid one neighbour's colour.
    result = result * pow(k - 1, n - cycle_len, MOD) % MOD
    return result


def main() -> None:
    """Read the graph from stdin and print the colouring count."""
    tokens = sys.stdin.read().split()
    n, k = int(tokens[0]), int(tokens[1])
    # Edges are given 1-indexed; convert to 0-indexed.
    edges = [
        (int(tokens[2 + 2 * i]) - 1, int(tokens[3 + 2 * i]) - 1)
        for i in range(n)
    ]
    print(count_colorings(n, k, edges))


if __name__ == "__main__":
    main()