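import sys
input = sys.stdin.readline

mod = 998244353

# Read N edges of an undirected graph whose vertex labels go up to N.
N, K = map(int, input().split())
DEG = [0] * (N + 5)
E = [[] for i in range(N + 1)]

for i in range(N):
    x, y = map(int, input().split())
    E[x].append(y)
    E[y].append(x)
    DEG[x] += 1
    DEG[y] += 1

# Repeatedly strip degree-1 vertices; whatever keeps degree 2 lies on the cycle.
Q = []
for i in range(N + 1):
    if DEG[i] == 1:
        Q.append(i)

while Q:
    x = Q.pop()
    DEG[x] -= 1
    for to in E[x]:
        DEG[to] -= 1
        if DEG[to] == 1:
            Q.append(to)

# L = number of vertices left on the cycle.
L = 0
for i in range(N + 5):
    if DEG[i] == 2:
        L += 1

# Walk around the cycle with the first vertex's color fixed.
# DP[0]: colorings where the current vertex has the same color as the first.
# DP[1]: colorings where the current vertex has a different color.
DP = [1, 0]
for i in range(L - 1):
    NDP = [0, 0]
    NDP[0] = DP[1]
    NDP[1] = DP[0] * (K - 1) + DP[1] * (K - 2)
    DP[0] = NDP[0] % mod
    DP[1] = NDP[1] % mod

# K choices for the first cycle vertex, DP[1] ways to close the cycle with the
# last vertex differing from the first, and K-1 choices for each of the
# N-L vertices off the cycle.
ANS = K * DP[1] * pow(K - 1, N - L, mod) % mod
print(ANS)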