"""Count proper K-colourings of a graph with N vertices and N edges.

Input (stdin): the first line is ``N K``; then N lines ``u v`` give the
1-indexed endpoints of each edge.
Output: the number of ways to colour the vertices with K colours so that the
endpoints of every edge differ, modulo 998244353.

NOTE(review): assumes the graph is connected, so it has exactly one cycle
(presumably guaranteed by the problem statement) — confirm.
"""
import sys
from collections import deque

MOD = 998244353


def cycle_length(n, adj):
    """Return the length of the cycle reachable from vertex 0.

    ``adj[u]`` lists the neighbours of vertex ``u`` (0-indexed). A parallel
    edge appears twice in the list and a self-loop appears in its own
    vertex's list, so cycles of length 2 and 1 are detected as well.

    Raises:
        ValueError: if no cycle is reachable from vertex 0.
    """
    dist = [-1] * n
    par = [-1] * n
    dist[0] = 0
    queue = deque([0])
    while queue:
        u = queue.popleft()
        skipped_parent = False
        for v in adj[u]:
            if v == par[u] and not skipped_parent:
                # Skip the tree edge back to the parent exactly once, so a
                # second (parallel) edge to the parent still closes a cycle.
                skipped_parent = True
                continue
            if dist[v] == -1:
                par[v] = u
                dist[v] = dist[u] + 1
                queue.append(v)
                continue
            # u-v is the non-tree edge. The cycle is this edge plus both tree
            # paths up to the lowest common ancestor, which need not be the
            # BFS root, so climb from the deeper endpoint until they meet.
            a, b = u, v
            length = 1
            while a != b:
                if dist[a] < dist[b]:
                    a, b = b, a
                a = par[a]
                length += 1
            return length
    raise ValueError("graph has no cycle")


def count_colorings(n, k, edges, mod=MOD):
    """Count proper ``k``-colourings of a connected ``n``-vertex, ``n``-edge graph.

    Args:
        n: number of vertices.
        k: number of available colours.
        edges: ``n`` pairs ``(u, v)`` of 0-indexed endpoints.
        mod: modulus for the result.

    Returns:
        The number of colourings with differing endpoint colours, modulo ``mod``.
    """
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    length = cycle_length(n, adj)
    # Chromatic polynomial of a cycle C_L: (k-1)^L + (-1)^L (k-1).
    # This is 0 for a self-loop (L = 1) and k(k-1) for a double edge (L = 2).
    sign = -1 if length % 2 else 1
    on_cycle = (pow(k - 1, length, mod) + sign * (k - 1)) % mod
    # Each vertex off the cycle hangs in a tree and only has to differ from
    # its already-coloured neighbour toward the cycle: k-1 choices each.
    return on_cycle * pow(k - 1, n - length, mod) % mod


def main():
    """Read the graph from stdin and print the colouring count."""
    read = sys.stdin.readline
    n, k = map(int, read().split())
    edges = []
    for _ in range(n):
        u, v = map(int, read().split())
        edges.append((u - 1, v - 1))
    print(count_colorings(n, k, edges))


if __name__ == "__main__":
    main()