"""Count proper K-colourings of a connected N-vertex, N-edge graph, mod 998244353.

Input (stdin): ``N K`` followed by N lines ``u v`` (1-indexed edge endpoints).
Output: the number of colourings in which adjacent vertices get different
colours, modulo 998244353.

NOTE(review): the graph is assumed (not validated) to be connected with
exactly one cycle, as an N-vertex/N-edge connected graph must be.
"""
import sys

MOD = 998244353


def _peel_leaves(n, adj, deg):
    """Repeatedly strip degree-1 vertices and return how many were stripped.

    ``deg`` is decremented in place. In a connected unicyclic graph, the
    stripped vertices are exactly the vertices not on the cycle.
    """
    stack = [u for u in range(n) if deg[u] == 1]
    removed = 0
    while stack:
        u = stack.pop()
        removed += 1
        for v in adj[u]:
            deg[v] -= 1
            # v just became a leaf of the remaining graph.
            if deg[v] == 1:
                stack.append(v)
    return removed


def _cycle_colorings(length, k, mod=MOD):
    """Return the number of proper k-colourings of a cycle of ``length`` vertices, mod ``mod``.

    Uses the chromatic polynomial of the cycle graph C_r:
    (k-1)^r + (-1)^r * (k-1). This covers r=1 (a self-loop, giving 0) and
    r=2 (a double edge, giving k*(k-1)).
    """
    sign = 1 if length % 2 == 0 else -1
    return (pow(k - 1, length, mod) + sign * (k - 1)) % mod


def solve(n, k, edges, mod=MOD):
    """Count proper k-colourings of a connected unicyclic graph, mod ``mod``.

    Args:
        n: number of vertices (labelled 1..n).
        k: number of available colours.
        edges: iterable of 1-indexed ``(u, v)`` pairs.
        mod: modulus for the result.

    Returns:
        The count modulo ``mod``.

    Raises:
        ValueError: if peeling leaves no cycle (the input is not unicyclic).
    """
    adj = [[] for _ in range(n)]
    deg = [0] * n
    for u, v in edges:
        u -= 1
        v -= 1
        adj[u].append(v)
        adj[v].append(u)
        deg[u] += 1
        deg[v] += 1

    tree_vertices = _peel_leaves(n, adj, deg)
    cycle_len = n - tree_vertices
    if cycle_len == 0:
        raise ValueError("graph has no cycle; expected a connected graph with N vertices and N edges")

    # Each tree vertex only has to differ from its already-coloured parent.
    return pow(k - 1, tree_vertices, mod) * _cycle_colorings(cycle_len, k, mod) % mod


def main():
    """Read the graph from stdin and print the colouring count."""
    data = sys.stdin.buffer.read().split()
    n, k = int(data[0]), int(data[1])
    nums = list(map(int, data[2:2 + 2 * n]))
    edges = zip(nums[0::2], nums[1::2])
    print(solve(n, k, edges))


if __name__ == "__main__":
    main()