"""Count K-colourings of a graph with N vertices and N edges, modulo 998244353.

Input (stdin):
    N K
    u_1 v_1
    ...
    u_N v_N        (1-indexed endpoints, one edge per line)

Output (stdout): the number of ways to assign one of K colours to every
vertex so that the two endpoints of each edge differ, modulo 998244353.

The computation peels away degree-1 vertices until only a cycle is left,
counts the colourings of that cycle, and multiplies by (K - 1) for every
peeled vertex.  NOTE(review): this is only the right count when the graph is
connected (exactly one cycle) -- presumably guaranteed by the task statement;
confirm against the problem constraints.
"""
from typing import Iterable, Tuple

MOD = 998244353


def count_tree_vertices(n: int, edges: Iterable[Tuple[int, int]]) -> int:
    """Return how many vertices are removed by repeatedly stripping leaves.

    Args:
        n: number of vertices, labelled 1..n in ``edges``.
        edges: iterable of ``(u, v)`` pairs with 1-indexed endpoints.

    Returns:
        The number of vertices whose remaining degree dropped to exactly 1
        at some point, i.e. the vertices that hang off the cycle.
    """
    adjacency = [[] for _ in range(n)]
    for u, v in edges:
        adjacency[u - 1].append(v - 1)
        adjacency[v - 1].append(u - 1)
    # A self-loop appears twice in its own list, so it contributes 2 here.
    degree = [len(neighbours) for neighbours in adjacency]

    stack = [v for v in range(n) if degree[v] == 1]
    for leaf in stack:
        # Zero the initial leaves so a later decrement can never bring
        # them back to 1 and queue them a second time.
        degree[leaf] = 0
    removed = len(stack)

    while stack:
        current = stack.pop()
        for neighbour in adjacency[current]:
            degree[neighbour] -= 1
            # Degrees only ever decrease, so each vertex passes through 1
            # at most once and is therefore pushed at most once.
            if degree[neighbour] == 1:
                stack.append(neighbour)
                removed += 1
    return removed


def count_cycle_colorings(length: int, k: int, mod: int = MOD) -> int:
    """Return the number of k-colourings of a cycle with ``length`` vertices.

    Walks along the cycle with the first vertex's colour fixed, tracking how
    many colourings of the prefix end on the *same* colour as the first
    vertex and how many end on a *different* one.  The cycle closes properly
    only when the last vertex differs from the first; the result is then
    multiplied by the ``k`` choices for the first vertex.

    A cycle of length 1 (a self-loop) yields 0.
    """
    same, different = 1, 0
    for _ in range(length - 1):
        same, different = (
            different,
            (same * (k - 1) + different * (k - 2)) % mod,
        )
    return different * k % mod


def count_colorings(
    n: int, k: int, edges: Iterable[Tuple[int, int]], mod: int = MOD
) -> int:
    """Return the colouring count for the whole graph, modulo ``mod``.

    Args:
        n: number of vertices.
        k: number of available colours.
        edges: iterable of ``(u, v)`` pairs with 1-indexed endpoints.
        mod: modulus for the result (defaults to 998244353).

    Raises:
        ValueError: if leaf-stripping removes every vertex, i.e. the graph
            contains no cycle.
    """
    tree_vertices = count_tree_vertices(n, edges)
    cycle_length = n - tree_vertices
    if cycle_length <= 0:
        raise ValueError("graph contains no cycle")
    # Every stripped vertex only has to avoid the colour of the one
    # neighbour it was attached to when it was removed: (k - 1) choices.
    return (
        count_cycle_colorings(cycle_length, k, mod)
        * pow(k - 1, tree_vertices, mod)
        % mod
    )


def main() -> None:
    """Read the graph from stdin and print the answer."""
    n, k = map(int, input().split())
    edges = [tuple(map(int, input().split())) for _ in range(n)]
    print(count_colorings(n, k, edges))


if __name__ == "__main__":
    main()