#include <stdio.h>
#include <bits/stdc++.h>
#include <atcoder/all>
using namespace atcoder;
using mint = modint998244353;
using namespace std;
#define rep(i,n) for (int i = 0; i < (n); ++i)
#define Inf32 1000000001
#define Inf64 1000000000000000001
struct combi{
	deque<mint> kaijou;
	deque<mint> kaijou_;
	
	combi(int n){
		kaijou.push_back(1);
		for(int i=1;i<=n;i++){
			kaijou.push_back(kaijou[i-1]*i);
		}
		
		mint b=kaijou[n].inv();
		
		kaijou_.push_front(b);
		for(int i=1;i<=n;i++){
			int k=n+1-i;
			kaijou_.push_front(kaijou_[0]*k);
		}
	}
	
	mint combination(int n,int r){
		if(r>n)return 0;
		mint a = kaijou[n]*kaijou_[r];
		a *= kaijou_[n-r];
		return a;
	}
	
	mint junretsu(int a,int b){
		mint x = kaijou_[a]*kaijou_[b];
		x *= kaijou[a+b];
		return x;
	}
	
	mint catalan(int n){
		return combination(2*n,n)/(n+1);
	}
	
};


// File-wide factorial tables covering indices 0..3,000,000; being a global,
// the tables are built before main() starts.
combi C{3000000};

// Lowest-common-ancestor queries on a rooted tree, answered with a
// binary-lifting table.
//
//   depth[v]      - number of edges between the root and v, or -1 when v was
//                   not reached from the root.
//   parents[v][j] - the 2^j-th ancestor of v, or -1 when there is none.
//   max_j         - number of table levels: the smallest j with
//                   2^j > number of vertices.
struct lca{
	
	std::vector<int> depth;
	std::vector<std::vector<int>> parents;
	
	int max_j=18;
	
	// n is the root vertex, E the adjacency lists (E.size() is the vertex count).
	// When n is not a valid vertex index the tables are still sized, but every
	// vertex stays marked as unreached.
	lca(int n,const std::vector<std::vector<int>> &E){
		const long long vertex_count = static_cast<long long>(E.size());
		
		max_j = 0;
		while((1LL<<max_j) <= vertex_count){
			++max_j;
		}
		depth.assign(E.size(),-1);
		parents.assign(E.size(),std::vector<int>(max_j,-1));
		
		if(n<0 || n>=vertex_count){
			return;
		}
		
		// Iterative traversal from the root. A vertex is finalised at the moment
		// it is discovered, so its parent's table row is already complete.
		std::stack<int> pending;
		pending.push(n);
		depth[n] = 0;
		while(!pending.empty()){
			const int cur = pending.top();
			pending.pop();
			for(const int next : E[cur]){
				if(depth[next]!=-1)continue;
				depth[next] = depth[cur]+1;
				parents[next][0] = cur;
				for(int j=1;j<max_j;j++){
					const int half = parents[next][j-1];
					if(half==-1)break; // no higher ancestors exist
					parents[next][j] = parents[half][j-1];
				}
				pending.push(next);
			}
		}
	}
	
	
	// Returns the k-th ancestor of u (k == 0 gives u itself), or -1 when that
	// ancestor does not exist, including negative k and k >= 2^max_j.
	int kth_parent(int u,int k) const{
		if(k<0)return -1;
		for(int i=0;i<max_j && k!=0 && u!=-1;i++){
			if(k&1){
				u = parents[u][i];
			}
			k>>=1;
		}
		// Bits still set here mean k >= 2^max_j, which exceeds the vertex
		// count, so no such ancestor can exist.
		return k==0 ? u : -1;
	}
	
	// Returns the lowest common ancestor of u and v, or -1 when either vertex
	// was not reached from the root.
	int query(int u,int v) const{
		if(depth[u]==-1 || depth[v]==-1)return -1;
		
		// Lift the deeper vertex until both sit at the same depth.
		if(depth[u]<depth[v])std::swap(u,v);
		u = kth_parent(u,depth[u]-depth[v]);
		
		if(u==v){
			return u;
		}
		
		// Take the largest jumps that keep u and v apart; afterwards both are
		// children of the answer.
		for(int i=max_j-1;i>=0;i--){
			if(parents[u][i]!=parents[v][i]){
				u = parents[u][i];
				v = parents[v][i];
			}
		}
		
		return parents[u][0];
		
	}
	
	// Returns the number of edges on the tree path between u and v, or -1 when
	// either vertex was not reached from the root.
	int get_distance(int u,int v) const{
		const int meet = query(u,v);
		if(meet==-1)return -1;
		return depth[u]+depth[v]-2*depth[meet];
	}
	
	
};
int main(){
	
	int N,K;
	cin>>N>>K;
	vector<vector<int>> E(N);
	dsu D(N);
	int x,y;
	rep(i,N){
		int a,b;
		cin>>a>>b;
		a--,b--;
		if(D.same(a,b)){
			x = a,y = b;
		}
		else{
			E[a].push_back(b);
			E[b].push_back(a);
			D.merge(a,b);
		}
	}
	
	lca L(0,E);
	int dis = L.get_distance(x,y);
	dis++;
	vector<mint> dp(2,0);
	dp[0] = K;
	rep(i,dis-1){
		vector<mint> ndp(2,0);
		ndp[0] += dp[1] * 1;
		ndp[1] += dp[0]*(K-1);
		ndp[1] += dp[1] * (K-2);
		swap(dp,ndp);
	}
	mint ans = dp[1];
	ans *= mint(K-1).pow(N - dis);
	cout<<ans.val()<<endl;
	
	return 0;
}