#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>

using ll = long long;

// Prime modulus in which the final count is reported.
const ll mod = 998244353;

// Adjacency list: adj[u] holds (neighbour, edge id) pairs. The edge id lets
// the traversal skip exactly the edge it arrived by, so a second, parallel
// edge to the parent is still recognised as closing a cycle.
using Graph = std::vector<std::vector<std::pair<int, int>>>;

// Depth-first search from vertex 1 that returns the length of the cycle it
// finds, or 0 when no already-visited vertex is ever reached again.
//
// Whenever an edge (other than the one used to enter the current vertex)
// leads to a vertex that already has a depth, the cycle length is taken as
// |depth[v] - depth[u]| + 1. If several such edges exist, the last one
// examined wins. A self-loop yields length 1.
//
// The search keeps its own stack instead of recursing, so a graph that is one
// long path cannot exhaust the call stack. Neighbours are still explored in
// adjacency-list order, one at a time, which keeps this a true DFS (every
// non-tree edge joins a vertex with one of its ancestors).
static int cycle_length_from_vertex_one(const Graph& adj) {
    const std::size_t slots = adj.size();
    std::vector<int> depth(slots, -1);            // -1 marks "not visited yet"
    std::vector<int> entered_by(slots, -1);       // id of the edge used to reach the vertex
    std::vector<std::size_t> next_idx(slots, 0);  // next neighbour to examine per vertex
    std::vector<int> path;                        // vertices on the current DFS path

    int cycle = 0;
    depth[1] = 0;
    path.push_back(1);
    while (!path.empty()) {
        const int u = path.back();
        if (next_idx[u] == adj[u].size()) {
            path.pop_back();
            continue;
        }
        const std::pair<int, int>& arc = adj[u][next_idx[u]++];
        const int v = arc.first;
        const int edge_id = arc.second;
        if (edge_id == entered_by[u]) {
            continue;  // do not walk straight back over the edge we came in on
        }
        if (depth[v] != -1) {
            cycle = std::abs(depth[v] - depth[u]) + 1;
        } else {
            depth[v] = depth[u] + 1;
            entered_by[v] = edge_id;
            path.push_back(v);
        }
    }
    return cycle;
}

// Evaluates the two-state recurrence over positions 1..len and returns the
// "differs from the first position" state at position len:
//
//   same[1] = k, diff[1] = 0
//   same[i] = diff[i-1]
//   diff[i] = same[i-1] * (k-1) + diff[i-1] * (k-2)      (all modulo `mod`)
//
// Closing position len back onto position 1 requires the two to differ, hence
// diff[len]. The result is 0 for len <= 1. `k` must already be reduced into
// [0, mod); k-1 and k-2 are normalised to be non-negative so every product
// stays below mod^2 and fits in a signed 64-bit integer.
static ll closed_walk_count(int len, ll k) {
    const ll km1 = (k - 1 + mod) % mod;
    const ll km2 = (k - 2 + mod) % mod;
    ll same = k;
    ll diff = 0;
    for (int i = 2; i <= len; ++i) {
        const ll next_same = diff;
        const ll next_diff = (same * km1 + diff * km2) % mod;
        same = next_same;
        diff = next_diff;
    }
    return diff;
}

// Reads n and k, then n undirected edges "x y" (1-based vertex numbers), and
// prints closed_walk_count(c, k) * (k-1)^(n-c) modulo 998244353, where c is
// the cycle length found from vertex 1. No trailing newline is written.
//
// NOTE(review): this matches counting the ways to assign one of k colours to
// each vertex of a connected graph with exactly one cycle so that adjacent
// vertices differ - presumably the original task; the statement is not part
// of this file, so confirm the connectivity assumption against it.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    ll k = 0;
    std::cin >> n >> k;

    // Always provide slots 0 and 1 so the search from vertex 1 is valid even
    // for degenerate input.
    const std::size_t slots = n > 0 ? static_cast<std::size_t>(n) + 1 : 2;
    Graph adj(slots);
    for (int edge_id = 0; edge_id < n; ++edge_id) {
        int x = 0;
        int y = 0;
        std::cin >> x >> y;
        adj[x].push_back(std::make_pair(y, edge_id));
        adj[y].push_back(std::make_pair(x, edge_id));
    }

    const int c = cycle_length_from_vertex_one(adj);

    // Reduce k once: every later use is modular, and multiplying by a raw k
    // could overflow 64 bits when k is large.
    const ll k_mod = ((k % mod) + mod) % mod;
    const ll km1 = (k_mod - 1 + mod) % mod;

    ll ret = closed_walk_count(c, k_mod);
    // Each of the n - c vertices off the cycle contributes a factor (k-1).
    for (int i = 1; i <= n - c; ++i) {
        ret = ret * km1 % mod;
    }

    std::cout << ret;
    return 0;
}