#include <stdio.h> #include <bits/stdc++.h> #include <atcoder/all> using namespace atcoder; using mint = modint998244353; using namespace std; #define rep(i,n) for (int i = 0; i < (n); ++i) #define Inf32 1000000001 #define Inf64 1000000000000000001 struct combi{ deque<mint> kaijou; deque<mint> kaijou_; combi(int n){ kaijou.push_back(1); for(int i=1;i<=n;i++){ kaijou.push_back(kaijou[i-1]*i); } mint b=kaijou[n].inv(); kaijou_.push_front(b); for(int i=1;i<=n;i++){ int k=n+1-i; kaijou_.push_front(kaijou_[0]*k); } } mint combination(int n,int r){ if(r>n)return 0; mint a = kaijou[n]*kaijou_[r]; a *= kaijou_[n-r]; return a; } mint junretsu(int a,int b){ mint x = kaijou_[a]*kaijou_[b]; x *= kaijou[a+b]; return x; } mint catalan(int n){ return combination(2*n,n)/(n+1); } }; combi C(3000000); struct lca{ vector<int> depth; vector<vector<int>> parents; int max_j=18; lca(int n,vector<vector<int>> &E){ rep(i,100){ if((1<<i)>E.size()){ max_j = i; break; } } depth.assign(E.size(),-1); parents.assign(E.size(),vector<int>(max_j,-1)); stack<int> s; s.push(n); depth[n] = 0; while(s.size()!=0){ int k = s.top(); s.pop(); for(int i=0;i<E[k].size();i++){ int x = E[k][i]; if(depth[x]!=-1)continue; s.push(x); depth[x] = depth[k]+1; for(int j=0;j<max_j;j++){ if(j==0){ parents[x][j] = k; } else{ parents[x][j] = parents[parents[x][j-1]][j-1]; } if(parents[x][j] == -1)break; } } } } int kth_parent(int u,int k){ for(int i=0;i<max_j;i++){ if(k==0)break; if(u==-1)break; if(k%2){ u = parents[u][i]; } k/=2; } return u; } int query(int u,int v){ if(depth[u]<depth[v])swap(u,v); u = kth_parent(u,depth[u]-depth[v]); if(u==v){ return u; } for(int i=max_j-1;i>=0;i--){ if(parents[u][i]!=parents[v][i]){ u = parents[u][i]; v = parents[v][i]; } } return parents[u][0]; } int get_distance(int u,int v){ return depth[u]+depth[v]-2*depth[query(u,v)]; } }; int main(){ int N,K; cin>>N>>K; vector<vector<int>> E(N); dsu D(N); int x,y; rep(i,N){ int a,b; cin>>a>>b; a--,b--; if(D.same(a,b)){ x = a,y = b; } else{ E[a].push_back(b); E[b].push_back(a); D.merge(a,b); } } lca L(0,E); int dis = L.get_distance(x,y); dis++; vector<mint> dp(2,0); dp[0] = K; rep(i,dis-1){ vector<mint> ndp(2,0); ndp[0] += dp[1] * 1; ndp[1] += dp[0]*(K-1); ndp[1] += dp[1] * (K-2); swap(dp,ndp); } mint ans = dp[1]; ans *= mint(K-1).pow(N - dis); cout<<ans.val()<<endl; return 0; }