// Counts proper vertex colorings (mod 998244353) of a graph read from stdin.
//
// Input: N and K, followed by N undirected edges "a b" (1-indexed vertices).
// The edge that closes a cycle (detected with a DSU) triggers a BFS over the
// edges added so far to measure the cycle length. The answer is
//   (#K-colorings of that cycle) * (K-1)^(N - cycle_len).
// NOTE(review): this product is only the full chromatic count when the graph
// is connected with exactly one cycle (N vertices, N edges). Presumably the
// problem guarantees that; confirm against the statement. A self-loop
// (cycle_len == 1) is counted as K, exactly as in the original code.
#include <iostream>
#include <numeric>
#include <queue>
#include <utility>
#include <vector>

namespace {

using ll = long long;
constexpr ll kMod = 998244353;

// Disjoint-set union with union by size and (iterative) path compression.
struct Dsu {
  std::vector<int> parent;
  std::vector<int> size;

  explicit Dsu(int n) : parent(n), size(n, 1) {
    std::iota(parent.begin(), parent.end(), 0);
  }

  // Returns the representative of v's set, compressing the path walked.
  int root(int v) {
    int r = v;
    while (parent[r] != r) r = parent[r];
    while (parent[v] != r) {
      const int next = parent[v];
      parent[v] = r;
      v = next;
    }
    return r;
  }

  // Joins the sets containing a and b.
  // Returns true if a and b were ALREADY connected (i.e. edge a-b closes a
  // cycle), false if the sets were merged.
  bool unite(int a, int b) {
    a = root(a);
    b = root(b);
    if (a == b) return true;
    if (size[a] < size[b]) std::swap(a, b);
    size[a] += size[b];
    parent[b] = a;
    return false;
  }
};

// Unweighted shortest-path length (in edges) from src to dst in gr,
// or -1 if dst is unreachable.
int bfs_distance(const std::vector<std::vector<int>>& gr, int src, int dst) {
  std::vector<int> dist(gr.size(), -1);
  dist[src] = 0;
  std::queue<int> que;
  que.push(src);
  while (!que.empty()) {
    const int u = que.front();
    que.pop();
    if (u == dst) break;  // dist[dst] was fixed when dst was enqueued
    for (const int v : gr[u]) {
      if (dist[v] == -1) {
        dist[v] = dist[u] + 1;
        que.push(v);
      }
    }
  }
  return dist[dst];
}

// Number of K-colorings (mod kMod) of a cycle of cycle_len vertices, times
// (K-1) for each of the remaining n - cycle_len vertices.
// For cycle_len <= 1 the cycle factor is K (matches the original program).
ll count_colorings(int n, ll k, int cycle_len) {
  // Reduce K first so no product below can exceed ~kMod^2 (no overflow).
  const ll kk = ((k % kMod) + kMod) % kMod;
  const ll k1 = (kk - 1 + kMod) % kMod;  // K - 1, as a non-negative residue
  const ll k2 = (kk - 2 + kMod) % kMod;  // K - 2, as a non-negative residue

  // For a path v0..vi:
  //   ends_differ = colorings with color(vi) != color(v0)
  //   paths       = all proper colorings of the path = K (K-1)^i
  ll ends_differ = kk;
  ll paths = kk;
  if (cycle_len > 1) {
    ends_differ = kk * k1 % kMod;
    paths = ends_differ;
  }
  for (int i = 2; i < cycle_len; ++i) {
    // Extending the path by one vertex: it must differ from its predecessor
    // and (to stay cycle-closable) from v0.
    const ll ends_equal = (paths - ends_differ + kMod) % kMod;
    ends_differ = (k2 * ends_differ + k1 * ends_equal) % kMod;
    paths = paths * k1 % kMod;
  }

  ll ans = ends_differ;
  // Every vertex outside the cycle hangs off a tree: K-1 choices each.
  for (int i = 0; i < n - cycle_len; ++i) ans = ans * k1 % kMod;
  return ans;
}

}  // namespace

int main() {
  std::ios_base::sync_with_stdio(false);
  std::cin.tie(nullptr);

  int n = 0;
  ll k = 0;
  std::cin >> n >> k;

  Dsu dsu(n);
  std::vector<std::vector<int>> gr(n);
  int cycle_len = 0;
  for (int i = 0; i < n; ++i) {
    int a = 0;
    int b = 0;
    std::cin >> a >> b;
    --a;
    --b;
    // a and b already connected: the existing a~b path plus this edge is the cycle.
    if (dsu.unite(a, b)) {
      cycle_len = bfs_distance(gr, a, b) + 1;
    }
    gr[a].push_back(b);
    gr[b].push_back(a);
  }

  std::cout << count_colorings(n, k, cycle_len) << '\n';
  return 0;
}