"""Count proper K-colourings of a connected graph with N vertices and N edges.

Such a graph is a single cycle with trees hanging off it.  Input format
(stdin): ``N K`` on the first line, then N lines ``u v`` (1-indexed edges).
The answer is printed modulo 998244353.
"""
import sys

MOD = 998244353


def _cycle_length(adj):
    """Return how many vertices survive repeated removal of degree-1 vertices.

    For a connected graph with exactly one cycle this is the cycle length.
    A self-loop appears twice in ``adj[u]`` and so counts 2 towards the
    degree; such a vertex is never peeled and yields a cycle of length 1.
    """
    deg = [len(nbrs) for nbrs in adj]
    stack = [v for v, d in enumerate(deg) if d == 1]
    remaining = len(adj)
    while stack:
        u = stack.pop()
        remaining -= 1
        for v in adj[u]:
            # Already-peeled neighbours drop to degree 0 and are not re-queued.
            deg[v] -= 1
            if deg[v] == 1:
                stack.append(v)
    return remaining


def count_colorings(n, k, edges, mod=MOD):
    """Return the number of proper k-colourings of the graph, modulo ``mod``.

    Args:
        n: number of vertices (labelled 0..n-1).
        k: number of available colours.
        edges: iterable of 0-indexed ``(u, v)`` pairs.  The graph is assumed
            to be connected with exactly ``n`` edges (one cycle) -- this is
            not checked.
        mod: modulus for the result.

    Returns:
        An int in ``[0, mod)``.
    """
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    cycle_len = _cycle_length(adj)
    # Chromatic polynomial of a cycle C_L: (k-1)^L + (-1)^L (k-1).  This is
    # also correct for L == 2 (double edge: k(k-1)) and L == 1 (self-loop: 0).
    sign_term = (k - 1) if cycle_len % 2 == 0 else -(k - 1)
    cycle_ways = (pow(k - 1, cycle_len, mod) + sign_term) % mod
    # Every vertex outside the cycle only has to differ from its parent.
    tree_ways = pow(k - 1, n - cycle_len, mod)
    return cycle_ways * tree_ways % mod


def main():
    """Read the graph from stdin and print the colouring count."""
    data = sys.stdin.buffer.read().split()
    n, k = int(data[0]), int(data[1])
    edges = [
        (int(data[2 + 2 * i]) - 1, int(data[3 + 2 * i]) - 1) for i in range(n)
    ]
    print(count_colorings(n, k, edges))


if __name__ == "__main__":
    main()