"""https://yukicoder.me/problems/no/2531

Reads a graph with N vertices and N undirected edges (1-indexed) and prints

    cycle(c, K) * (K - 1) ** (N - c)   (mod 998244353),

where N - c vertices are removed by repeatedly peeling away vertices of
degree <= 1, and c is the size of what remains (the 2-core).  For a connected
graph with N vertices and N edges the 2-core is exactly its unique cycle, and
the formula counts colourings with K colours in which adjacent vertices
differ.  Connectivity is assumed, presumably guaranteed by the problem
statement (see the URL above) -- it is not checked here.
"""

import math  # noqa: F401  (unused; kept from the original file)
import sys  # noqa: F401  (kept from the original file)
from collections import deque
from sys import stdin

MOD = 998244353
mod = MOD  # original lowercase name, kept for compatibility


def cycle(r, K):
    """Return the number of proper K-colourings of a cycle of length r, mod MOD.

    Uses the closed form (K-1)**r + (-1)**r * (K-1).  This is algebraically
    identical to the alternating sum
    sum_{i=0}^{r-2} (-1)**i * K * (K-1)**(r-1-i) that the original loop
    evaluated, but needs one modular exponentiation instead of r - 1.
    r == 1 (a self-loop) yields 0 and r == 2 (a double edge) yields K*(K-1).
    For r <= 0 the result is 0, matching the original's empty sum.
    """
    if r <= 0:
        return 0
    sign = 1 if r % 2 == 0 else -1
    return (pow(K - 1, r, MOD) + sign * (K - 1)) % MOD


def _count_peeled(n, adj):
    """Count vertices removed by repeatedly deleting vertices of degree <= 1.

    adj[v] lists v's neighbours with multiplicity (a self-loop appears twice),
    so len(adj[v]) is v's degree.  Vertices that are never removed form the
    2-core of the graph.
    """
    degree = [len(nbrs) for nbrs in adj]
    queued = [False] * n
    queue = deque()
    for v in range(n):
        if degree[v] <= 1:
            queued[v] = True
            queue.append(v)

    peeled = 0
    while queue:
        v = queue.popleft()
        peeled += 1
        for w in adj[v]:
            if queued[w]:
                # Already removed or scheduled for removal; its degree no
                # longer matters and it must not be enqueued twice.
                continue
            degree[w] -= 1
            if degree[w] == 1:
                queued[w] = True
                queue.append(w)
    return peeled


def solve(n, k, edges):
    """Return the answer for n vertices, k colours and 0-indexed edge pairs."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    peeled = _count_peeled(n, adj)
    # Colour the remaining cycle first, then re-attach peeled vertices in
    # reverse removal order: each has exactly one already-coloured neighbour
    # at that moment, leaving k - 1 choices.
    return cycle(n - peeled, k) * pow(k - 1, peeled, MOD) % MOD


def main():
    """Read N, K and N edges from stdin and print the answer."""
    tokens = iter(map(int, stdin.buffer.read().split()))
    n = next(tokens)
    k = next(tokens)
    # Exactly n edges follow, given with 1-indexed endpoints.
    edges = [(next(tokens) - 1, next(tokens) - 1) for _ in range(n)]
    print(solve(n, k, edges))


if __name__ == "__main__":
    main()