"""Count proper k-colourings of a graph with n vertices and n edges, mod 998244353.

Input (stdin): ``n k`` followed by ``n`` edges ``a b`` with 1-based endpoints.
Output: the number of ways to colour every vertex with one of ``k`` colours so
that the two endpoints of each edge differ, modulo 998244353.
"""
import sys

MOD = 998244353


def _core_size(n, adj, deg):
    """Return how many vertices survive repeatedly stripping degree-1 vertices.

    ``deg`` is copied, not mutated. Vertices that start with degree 1 and any
    vertex that later drops to degree 1 are peeled; what remains is the cycle.
    """
    deg = list(deg)
    removed = [d == 1 for d in deg]
    queue = [v for v in range(n) if deg[v] == 1]
    head = 0
    while head < len(queue):
        leaf = queue[head]
        head += 1
        for nxt in adj[leaf]:
            # Already-removed neighbours may go to 0 or below; the `removed`
            # flag keeps them from being queued twice.
            deg[nxt] -= 1
            if deg[nxt] == 1 and not removed[nxt]:
                removed[nxt] = True
                queue.append(nxt)
    return n - sum(removed)


def count_colorings(n, k, edges, mod=MOD):
    """Return the number of proper ``k``-colourings of the graph, modulo ``mod``.

    Args:
        n: number of vertices (0-based labels ``0..n-1``).
        k: number of available colours.
        edges: iterable of ``(a, b)`` pairs with 0-based endpoints; self-loops
            and repeated edges are allowed.
        mod: modulus for the result.

    NOTE(review): the formula assumes the graph is connected, so n vertices and
    n edges form exactly one cycle plus attached trees. This is not validated.
    """
    adj = [[] for _ in range(n)]
    deg = [0] * n
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)
        deg[a] += 1
        deg[b] += 1

    cycle = _core_size(n, adj, deg)
    # Proper colourings of a cycle of length l: (k-1)^l + (-1)^l * (k-1).
    # This gives 0 for a self-loop (l=1) and k(k-1) for a double edge (l=2).
    sign_term = (k - 1) if cycle % 2 == 0 else -(k - 1)
    cycle_ways = (pow(k - 1, cycle, mod) + sign_term) % mod
    # Each tree vertex hanging off the cycle only has to differ from its parent.
    tree_ways = pow(k - 1, n - cycle, mod)
    return cycle_ways * tree_ways % mod


def main():
    """Read the graph from stdin and print the colouring count."""
    data = sys.stdin.buffer.read().split()
    n, k = int(data[0]), int(data[1])
    edges = [
        (int(data[2 + 2 * i]) - 1, int(data[3 + 2 * i]) - 1) for i in range(n)
    ]
    print(count_colorings(n, k, edges))


if __name__ == "__main__":
    main()