"""Count proper k-colorings (mod 998244353) of a graph with n vertices and n edges.

Input (stdin): ``n k`` followed by n lines ``u v`` (1-based endpoints).
A graph with n vertices and n edges that is connected contains exactly one
cycle; the answer is the number of proper colorings of that cycle times
(k-1) for every vertex outside it (each such vertex only has to differ from
the neighbour through which it hangs off the rest of the graph).
"""
import sys

MOD = 998244353


def cycle_length(n, edges):
    """Return how many vertices remain after repeatedly stripping degree-1 vertices.

    For a connected graph with n vertices and n edges (assumed — not checked)
    the survivors are exactly the vertices of its unique cycle. Parallel edges
    yield a cycle of length 2 and a self-loop a cycle of length 1, since each
    endpoint occurrence contributes to the degree.

    Args:
        n: number of vertices, labelled 0..n-1.
        edges: iterable of (u, v) pairs with 0-based endpoints.

    Returns:
        The number of vertices that are never peeled off.
    """
    adj = [[] for _ in range(n)]
    deg = [0] * n
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
        deg[u] += 1
        deg[v] += 1  # a self-loop (u == v) adds 2, keeping u off the leaf list

    # Iterative peeling avoids the deep recursion of a DFS on long paths.
    stack = [v for v in range(n) if deg[v] == 1]
    remaining = n
    while stack:
        v = stack.pop()
        deg[v] = 0  # mark removed so neighbours below skip it
        remaining -= 1
        for w in adj[v]:
            if deg[w]:
                deg[w] -= 1
                if deg[w] == 1:
                    stack.append(w)
    return remaining


def count_colorings(n, k, edges):
    """Return the number of proper k-colorings of the graph, modulo MOD.

    Assumes the graph is connected with n vertices and n edges (one cycle).

    Args:
        n: number of vertices, labelled 0..n-1.
        k: number of available colours.
        edges: iterable of (u, v) pairs with 0-based endpoints.

    Returns:
        (cycle colourings) * (k-1)^(n - cycle length), reduced modulo MOD.
    """
    c = cycle_length(n, edges)
    base = (k - 1) % MOD
    # Chromatic polynomial of a cycle of length c: (k-1)^c + (-1)^c (k-1).
    # It also gives 0 for a self-loop (c == 1) and k(k-1) for a double edge.
    sign_term = base if c % 2 == 0 else MOD - base
    ring = (pow(base, c, MOD) + sign_term) % MOD
    # Each non-cycle vertex must differ only from its neighbour toward the cycle.
    return ring * pow(base, n - c, MOD) % MOD


def main():
    """Read the graph from stdin and print the colouring count."""
    data = sys.stdin.buffer.read().split()
    n, k = int(data[0]), int(data[1])
    it = iter(map(int, data[2:2 + 2 * n]))
    edges = [(u - 1, v - 1) for u, v in zip(it, it)]
    print(count_colorings(n, k, edges))


if __name__ == "__main__":
    main()