"""Count proper k-colourings of a connected graph with n vertices and n edges.

Input (stdin): ``n k`` followed by ``n`` lines ``u v`` (1-indexed endpoints).
Output: the number of proper colourings modulo 998244353.

A connected graph with as many edges as vertices contains exactly one cycle.
Every vertex off the cycle is attached through a tree edge to a vertex that is
already coloured, so it only has to avoid one colour: a factor of (k - 1) each.
The cycle of length c contributes its chromatic polynomial
(k - 1)^c + (-1)^c (k - 1), which also yields k(k - 1) for a pair of parallel
edges (c == 2) and 0 for a self-loop (c == 1).
"""
import sys

MOD = 998244353


def cycle_length(n, edges):
    """Return the number of vertices on the cycle of a unicyclic graph.

    ``edges`` holds 0-indexed ``(u, v)`` pairs; the graph is assumed to be
    connected with exactly ``n`` edges. Vertices of degree 1 are peeled off
    repeatedly; what remains is the cycle. Because this counts edges rather
    than skipping a parent *vertex*, a pair of parallel edges is reported as a
    cycle of length 2 and a self-loop (which adds 2 to its vertex's degree) as
    a cycle of length 1. The peeling is iterative, so deep trees cannot hit a
    recursion limit.
    """
    degree = [0] * n
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
        degree[u] += 1
        degree[v] += 1

    removed = [False] * n
    leaves = [v for v in range(n) if degree[v] == 1]
    peeled = 0
    while leaves:
        v = leaves.pop()
        removed[v] = True
        peeled += 1
        for w in adj[v]:
            if not removed[w]:
                degree[w] -= 1
                # w just became a leaf of the remaining graph.
                if degree[w] == 1:
                    leaves.append(w)
    return n - peeled


def count_colorings(n, k, edges, mod=MOD):
    """Return the number of proper k-colourings of the graph modulo ``mod``.

    ``edges`` holds 0-indexed ``(u, v)`` pairs of a connected graph with
    ``n`` vertices and ``n`` edges.
    """
    c = cycle_length(n, edges)
    base = (k - 1) % mod
    # Chromatic polynomial of a c-cycle: (k-1)^c + (-1)^c (k-1).
    sign_term = base if c % 2 == 0 else -base
    cycle_ways = (pow(base, c, mod) + sign_term) % mod
    # Each vertex hanging off the cycle must differ from its attachment point.
    return cycle_ways * pow(base, n - c, mod) % mod


def main():
    """Read the graph from stdin and print the colouring count."""
    data = sys.stdin.buffer.read().split()
    n, k = int(data[0]), int(data[1])
    nums = map(int, data[2:2 + 2 * n])
    edges = [(u - 1, v - 1) for u, v in zip(nums, nums)]
    print(count_colorings(n, k, edges))


if __name__ == "__main__":
    main()