from sys import setrecursionlimit
setrecursionlimit(10**6)
MOD = 998244353


def find_cycle_length(n, edges):
    """Return the number of vertices on the cycle in vertex 0's component.

    ``edges`` holds 0-based ``(u, v)`` pairs of an undirected graph on
    vertices ``0 .. n-1``. A doubled edge counts as a cycle of length 2 and a
    self-loop as a cycle of length 1. Returns 0 when ``n == 0`` or when the
    component containing vertex 0 has no cycle.

    The DFS is iterative, so a long path cannot exhaust the interpreter's
    (or the C) stack.
    """
    if n == 0:
        return 0
    adj = [[] for _ in range(n)]
    for eid, (u, v) in enumerate(edges):
        # Keep the edge id so the DFS skips exactly the tree edge it arrived
        # on; skipping by parent *vertex* would also hide a parallel edge.
        adj[u].append((v, eid))
        adj[v].append((u, eid))
    depth = [-1] * n
    depth[0] = 0
    # Stack entries: (vertex, id of the edge used to reach it, adjacency iterator).
    stack = [(0, -1, iter(adj[0]))]
    while stack:
        v, in_edge, neighbours = stack[-1]
        for c, eid in neighbours:
            if eid == in_edge:
                continue
            if depth[c] >= 0:
                # First non-tree edge seen: c is an ancestor of v on the
                # current DFS path, so this edge closes the cycle.
                return depth[v] - depth[c] + 1
            depth[c] = depth[v] + 1
            stack.append((c, eid, iter(adj[c])))
            break
        else:
            # Adjacency exhausted: this vertex is finished.
            stack.pop()
    return 0


def cycle_colorings(m, k, mod=MOD):
    """Return the number of proper k-colourings of a cycle on m >= 1 vertices, mod ``mod``.

    Uses the chromatic polynomial of the cycle graph, (k-1)^m + (-1)^m (k-1).
    This gives 0 for m == 1 (a self-loop) and k(k-1) for m == 2 (a doubled
    edge).
    """
    sign = -1 if m % 2 else 1
    return (pow(k - 1, m, mod) + sign * (k - 1)) % mod


def count_colorings(n, k, edges, mod=MOD):
    """Count proper k-colourings, mod ``mod``, of a connected graph with n vertices and one cycle.

    ``edges`` are 0-based ``(u, v)`` pairs. The cycle is coloured first.
    Every other vertex is then reached through a single tree edge from an
    already-coloured neighbour, so it has k-1 choices (this relies on the
    graph being connected).

    Raises:
        ValueError: if no cycle is reachable from vertex 0, i.e. the input is
            not a connected graph with n vertices and n edges.
    """
    m = find_cycle_length(n, edges)
    if m == 0:
        raise ValueError("no cycle reachable from vertex 1")
    return cycle_colorings(m, k, mod) * pow(k - 1, n - m, mod) % mod


def main():
    """Read ``n k`` and then n 1-based edges from stdin; print the colouring count mod 998244353."""
    n, k = map(int, input().split())
    edges = []
    for _ in range(n):
        u, v = map(int, input().split())
        edges.append((u - 1, v - 1))
    print(count_colorings(n, k, edges))


if __name__ == "__main__":
    main()