# Script (collapsed onto one line; newlines/indentation were lost upstream).
#
# Visible behaviour:
#   * Reads n and k, then n edges "a b" (1-indexed, converted to 0-indexed)
#     into adjacency lists e[] and degree counts d[].
#     NOTE(review): n vertices with n edges -- presumably a connected graph with
#     exactly one cycle; confirm against the problem statement.
#   * All arithmetic is reduced modulo M = 998244353.
#   * q1/q2 tables: q1[i] = (k-2)*q1[i-1] + (k-1)*q1[i-2] (q2[i-1] == q1[i-2]),
#     seeded with q1[1] = 0 and q1[2] = k*(k-1). This matches the standard
#     recurrence for the number of proper k-colourings of a cycle C_i,
#     i.e. (k-1)**i + (-1)**i * (k-1).
#   * If no vertex has degree 1 (`1 not in d`), the graph is treated as a single
#     cycle and the script prints that count, then exits:
#       - k == 2 -> prints 2 if n is even, else 0;
#       - n == 3 -> prints k*(k-1)*(k-2) mod M;
#       - otherwise -> prints q2[n+1] (== q1[n]).
#     NOTE(review): both special cases appear to agree with q2[n+1] as well.
#   * Otherwise it picks S = the first degree-1 vertex and starts an iterative
#     DFS from it, using v (visited flags), g (per-vertex next-neighbour index),
#     q (explicit stack) and f (a flag whose use is not visible here).
#
# TODO(review): the tail is corrupted. "while g[s]0:" is not valid Python and
# looks like text between a '<' and a '>' was stripped (HTML sanitisation),
# e.g. "while g[s]<len(e[s]): ... while len(q)>0:". The DFS body and the final
# answer computation are missing, so recover the original source before use.
n,k=map(int,input().split()) M=998244353 e=[[] for i in range(n)] d=[0]*n for i in range(n): a,b=map(int,input().split()) a-=1 b-=1 e[a]+=[b] e[b]+=[a] d[a]+=1 d[b]+=1 q1=[0,0,k*(k-1)%M] q2=[0,k,0] for i in range(3,n+2): q1+=[(q1[i-1]*(k-2)+q2[i-1]*(k-1))%M] q2+=[q1[i-1]] if 1 not in d: if k==2: print(2*(n%2==0)) else: if n==3: print(k*(k-1)*(k-2)%M) else: print(q2[n+1]) exit() for i in range(n): if d[i]==1: S=i break v=[0]*n v[S]=1 g=[0]*n q=[S] f=0 while len(q)>0: s=q[-1] while g[s]0: s=q[-1] while g[s]