# Reads "n k", then n edge lines "a b" (1-indexed, converted to 0-indexed) into
# the adjacency lists `e` and the degree counts `d`. Entries of q1 are reduced
# modulo M = 998244353.
#
# q1/q2 tables, seeded with q1[1] = 0, q2[1] = k, filled for i = 2 .. n+1:
#   q1[i] = (q1[i-1]*(k-2) + q2[i-1]*(k-1)) % M
#   q2[i] = q1[i-1]
# Working the first terms by hand gives q1[2] = k(k-1), q2[2] = 0,
# q1[3] = k(k-1)(k-2) and q2[3] = k(k-1). These match counting length-i colour
# sequences with adjacent entries distinct, split by whether the last entry
# differs from the first (q1) or equals it (q2). Under that reading,
# q2[n+1] == q1[n] is the number of proper k-colourings of an n-cycle.
# NOTE(review): this interpretation is inferred from the recurrence. Confirm
# it against the original problem statement.
#
# If no vertex has degree exactly 1, the script prints q2[n+1] and exits.
# Otherwise it takes the first vertex S with degree 1 and begins what appears
# to be an iterative, stack-based traversal. It uses the explicit stack `q`,
# visited flags `v` and per-vertex adjacency cursors `g`.
#
# NOTE(review): this file is not runnable as stored. Problems:
#   - The whole script has been collapsed onto one line. Every statement needs
#     its newline and indentation restored.
#   - The text after `while g[s]` is damaged. A span seems to have been
#     dropped, probably everything between a `<` and a later `>` (for example
#     by HTML stripping).
#   - The script is truncated at the end.
# The traversal body and the final output cannot be recovered from here.
# Restore them from the original source before use.
n,k=map(int,input().split()) M=998244353 e=[[] for i in range(n)] d=[0]*n for i in range(n): a,b=map(int,input().split()) a-=1 b-=1 e[a]+=[b] e[b]+=[a] d[a]+=1 d[b]+=1 q1=[0,0] q2=[0,k] for i in range(2,n+2): q1+=[(q1[i-1]*(k-2)+q2[i-1]*(k-1))%M] q2+=[q1[i-1]] if 1 not in d: print(q2[n+1]) exit() for i in range(n): if d[i]==1: S=i break v=[0]*n v[S]=1 g=[0]*n q=[S] f=0 while len(q)>0: s=q[-1] while g[s]0: s=q[-1] while g[s]