#include using namespace std; using ll = long long; #include using mint = atcoder::modint998244353; int main(){ cin.tie(nullptr); ios::sync_with_stdio(false); int n,k; cin>>n>>k; vector dp(n+1,0); dp[1] = 1; dp[2] = mint(k) * mint(k-1); for(int i = 3;i<=n;i++){ dp[i] = mint(k-2) * dp[i-1]; if(i>3) dp[i] += mint(k-1) * dp[i-2]; } mint ans = 0; vector> g(n); vector cnt(n,0); int all = 0; for(int i = 0;i>u>>v; u--;v--; g[u].push_back(v); g[v].push_back(u); cnt[u]++; cnt[v]++; } vector que; for(int i = 0;i