結果

問題 No.1303 Inconvenient Kingdom
コンテスト
ユーザー miscalc
提出日時 2022-10-24 20:28:49
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果 WA
実行時間 -
コード長 11,737 bytes
コンパイル時間 5,406 ms
コンパイル使用メモリ 278,812 KB
最終ジャッジ日時 2025-02-08 12:19:43
ジャッジサーバーID (参考情報): judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 18 WA * 13 TLE * 3
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
// Short names for the integer / floating / tuple types used throughout.
using ll = long long;
using ld = long double;
using pll = std::pair<ll, ll>;
using tlll = std::tuple<ll, ll, ll>;

// Large sentinel value 2^60; INF + INF still fits in ll.
constexpr ll INF = ll{1} << 60;
// Replace a with b when b is strictly smaller (a > b); returns whether a changed.
template<class T> bool chmin(T& a, T b)
{
  const bool updated = a > b;
  if (updated)
    a = b;
  return updated;
}

// Replace a with b when b is strictly larger (a < b); returns whether a changed.
template<class T> bool chmax(T& a, T b)
{
  const bool updated = a < b;
  if (updated)
    a = b;
  return updated;
}
// Remainder of A modulo M shifted into [0, M). Intended for M > 0; for M < 0
// a negative remainder is moved further from zero, so callers should pass M > 0.
long long safemod(long long A, long long M)
{
  long long res = A % M;
  if (res < 0)
    res += M;
  return res;
}

// floor(A / B) for B != 0 and any signs of A and B.
// Built from the truncating quotient/remainder, so it never forms A - r,
// -A or A + B - 1 and cannot overflow (except LLONG_MIN / -1, as with `/`).
long long divfloor(long long A, long long B)
{
  long long q = A / B, r = A % B;
  // `/` truncates toward zero; step down when the exact quotient is negative
  // (a nonzero remainder whose sign differs from B's).
  if (r != 0 && ((r < 0) != (B < 0)))
    q--;
  return q;
}

// ceil(A / B) for B != 0 and any signs of A and B (same overflow guarantees).
long long divceil(long long A, long long B)
{
  long long q = A / B, r = A % B;
  // `/` truncates toward zero; step up when the exact quotient is positive.
  if (r != 0 && ((r < 0) == (B < 0)))
    q++;
  return q;
}
// A^B by binary exponentiation; returns 1 for B <= 0.
// Kept special cases: A == 0 or A == 1 returns A itself (so pow_ll(0, 0) == 0),
// and A == -1 returns +-1 by the parity of B.
// Overflow is not detected: the caller must ensure |A^B| fits in long long.
long long pow_ll(long long A, long long B)
{
  if (A == 0 || A == 1)
    return A;
  if (A == -1)
    return (B & 1) ? -1 : 1;
  long long res = 1;
  while (B > 0)
  {
    if (B & 1)
      res *= A;
    B >>= 1;
    // Square only while higher bits remain, so every intermediate power of A
    // is bounded by |A^B| and cannot overflow when the result itself fits.
    if (B > 0)
      A *= A;
  }
  return res;
}
// floor(log_A B) for A >= 2: the largest k with A^k <= B (0 when B < A,
// including B <= 0). The loop bound B / A keeps tmp * A <= B, so no overflow.
long long logfloor(long long A, long long B)
{
  assert(A >= 2);
  long long res = 0;
  for (long long tmp = 1; tmp <= B / A; tmp *= A)
    res++;
  return res;
}

// ceil(log_A B) for A >= 2: the smallest k with A^k >= B (0 when B <= 1).
long long logceil(long long A, long long B)
{
  assert(A >= 2);
  long long res = 0, tmp = 1;
  while (tmp < B)
  {
    res++;
    // tmp * A >= B already holds; stop before the multiplication, which
    // could overflow when B is close to LLONG_MAX.
    if (tmp > (B - 1) / A)
      break;
    tmp *= A;
  }
  return res;
}
// Sum of the n-term arithmetic progression a, a + d, ..., a + (n - 1) d.
// n(n-1)/2 is formed by halving whichever of n, n - 1 is even, keeping the
// intermediate as small as possible.
long long arisum_ll(long long a, long long d, long long n)
{
  const long long tri = (n & 1) ? ((n - 1) / 2) * n : (n / 2) * (n - 1);
  return n * a + tri * d;
}

// Sum of n terms of an arithmetic progression with first term a and last
// term l, i.e. n (a + l) / 2. For odd n, a + l = 2a + (n - 1) d is even, so
// halving it is exact; for even n, n is halved instead.
long long arisum2_ll(long long a, long long l, long long n)
{
  return (n & 1) ? ((a + l) / 2) * n : (n / 2) * (a + l);
}

// Sum of the progression a, a + d, ..., l. Requires d != 0 and (l - a)
// divisible by d; d == 0 would otherwise divide by zero.
// NOTE: if d points away from l the term count (l - a) / d + 1 is not
// positive and the result is not a meaningful sum.
long long arisum3_ll(long long a, long long l, long long d)
{
  assert(d != 0 && (l - a) % d == 0);
  return arisum2_ll(a, l, (l - a) / d + 1);
}
// Remove adjacent duplicates from V in place (std::unique followed by erase).
template<class T> void unique(vector<T> &V)
{
  auto new_end = std::unique(V.begin(), V.end());
  V.erase(new_end, V.end());
}

// Sort V ascending and keep exactly one copy of each distinct value.
template<class T> void sortunique(vector<T> &V)
{
  std::sort(V.begin(), V.end());
  unique(V);
}
// Print A followed by a newline, then terminate the program with status 0.
#define FINALANS(A)           \
  do                          \
  {                           \
    std::cout << (A) << '\n'; \
    std::exit(0);             \
  } while (false)

// Print the elements of V on one line separated by single spaces, then '\n'.
// An empty vector prints just the newline.
template<class T> void printvec(const vector<T> &V)
{
  for (std::size_t k = 0; k < V.size(); k++)
  {
    if (k != 0)
      std::cout << ' ';
    std::cout << V[k];
  }
  std::cout << '\n';
}

// Print each element of V on its own line.
template<class T> void printvect(const vector<T> &V)
{
  for (const auto &elem : V)
    std::cout << elem << '\n';
}

// Print a 2D vector one row per line, each row formatted as by printvec.
template<class T> void printvec2(const vector<vector<T>> &V)
{
  for (const auto &row : V)
    printvec(row);
}
//*
#include <atcoder/all>
using namespace atcoder;
using mint = modint998244353;
//using mint = modint1000000007;
//using mint = modint;
//*/

// Dense matrix over T, stored as a vector of rows.
// T needs +, -, *, unary -, ==, != and construction from int. The routines
// that divide (operator/=, det, inv, row_reduction, system) assume T is a
// field such as a modint or a floating-point type; with an integral T the
// `T(1) / x` steps truncate and the results are wrong.
template<class T>
struct Matrix : std::vector<std::vector<T>>
{
  using std::vector<std::vector<T>>::vector;
  using std::vector<std::vector<T>>::operator=;

  // n x m matrix with `diag` on the main diagonal and `non_diag` elsewhere;
  // Matrix(n, n, T(1)) is the identity.
  Matrix(int n, int m, T diag = 0, T non_diag = 0)
    : std::vector<std::vector<T>>(n, std::vector<T>(m, non_diag))
  {
    for (int i = 0; i < std::min(n, m); i++)
      (*this)[i][i] = diag;
  }

  // Row / column counts. cols() is 0 for a matrix with no rows, so the
  // methods below never index row 0 of an empty matrix.
  int rows() const { return (int)this->size(); }
  int cols() const { return this->empty() ? 0 : (int)(*this)[0].size(); }

  Matrix operator-() const
  {
    Matrix res(*this);
    for (auto &row : res)
      for (auto &&x : row)
        x = -x;
    return res;
  }

  Matrix &operator+=(const Matrix &A)
  {
    const int N = rows(), M = cols();
    assert(A.rows() == N && A.cols() == M);
    for (int i = 0; i < N; i++)
      for (int j = 0; j < M; j++)
        (*this)[i][j] += A[i][j];
    return *this;
  }

  Matrix &operator-=(const Matrix &A)
  {
    const int N = rows(), M = cols();
    assert(A.rows() == N && A.cols() == M);
    for (int i = 0; i < N; i++)
      for (int j = 0; j < M; j++)
        (*this)[i][j] -= A[i][j];
    return *this;
  }

  Matrix &operator*=(const T x)
  {
    for (auto &row : *this)
      for (auto &&e : row)
        e *= x;
    return *this;
  }
  Matrix &operator/=(const T x) { return (*this) *= (T(1) / x); }

  friend Matrix &operator*=(const T x, Matrix &A) { return A *= x; }

  // Matrix-vector product; v.size() must equal cols().
  std::vector<T> operator*(const std::vector<T> &v) const
  {
    const int N = rows(), M = cols();
    assert((int)v.size() == M);

    std::vector<T> res(N, T(0));
    for (int i = 0; i < N; i++)
      for (int j = 0; j < M; j++)
        res[i] += (*this)[i][j] * v[j];
    return res;
  }

  // Matrix product, O(N * M * K).
  Matrix operator*(const Matrix &A) const
  {
    const int N = rows(), M = cols();
    assert(A.rows() == M);
    const int K = A.cols();

    Matrix res(N, K, T(0));
    for (int i = 0; i < N; i++)
      for (int j = 0; j < M; j++)
        for (int k = 0; k < K; k++)
          res[i][k] += (*this)[i][j] * A[j][k];
    return res;
  }

  // k-th power by binary exponentiation; k <= 0 gives the identity.
  Matrix pow(long long k) const
  {
    const int N = rows();
    assert(N == cols());
    Matrix res(N, N, T(1)), base(*this);
    for (; k > 0; k >>= 1)
    {
      if (k & 1)
        res *= base;
      base *= base;
    }
    return res;
  }

  Matrix operator+(const Matrix &A) const { return Matrix(*this) += A; }
  Matrix operator-(const Matrix &A) const { return Matrix(*this) -= A; }
  Matrix operator*(const T x) const { return Matrix(*this) *= x; }
  Matrix operator/(const T x) const { return Matrix(*this) /= x; }
  // Scalar on the left. (Previously called the undefined name `matrix` and
  // took a non-const reference, so any use failed to compile.)
  friend Matrix operator*(const T x, const Matrix &A) { return Matrix(A) *= x; }
  Matrix &operator*=(const Matrix &A) { return (*this) = (*this) * A; }

  // Determinant by Gaussian elimination, O(N^3). A 0x0 matrix has det 1.
  T det() const
  {
    Matrix A(*this);
    const int N = A.rows();
    assert(A.cols() == N);
    T res = T(1);
    for (int i = 0; i < N; i++)
    {
      // Bring the first row k >= i with a nonzero entry in column i up to
      // row i by adjacent swaps, flipping the sign once per swap.
      for (int k = i; k < N; k++)
      {
        if (A[k][i] != T(0))
        {
          for (int l = k - 1; l >= i; l--)
          {
            std::swap(A[l], A[l + 1]);
            res = -res;
          }
          break;
        }
      }
      if (A[i][i] == T(0))
        return T(0);
      res *= A[i][i];
      const T inv_pivot = T(1) / A[i][i];
      // Columns < i of row i are already zero, so start at column i.
      for (int j = i; j < N; j++)
        A[i][j] *= inv_pivot;
      for (int k = i + 1; k < N; k++)
      {
        const T factor = A[k][i];
        for (int j = i; j < N; j++)
          A[k][j] -= A[i][j] * factor;
      }
    }
    return res;
  }

  // Inverse by Gauss-Jordan elimination. Returns an empty (0x0) matrix if
  // the matrix is singular.
  Matrix inv() const
  {
    Matrix A(*this);
    const int N = A.rows();
    assert(A.cols() == N);
    Matrix B(N, N, T(1));
    for (int i = 0; i < N; i++)
    {
      for (int k = i; k < N; k++)
      {
        if (A[k][i] != T(0))
        {
          std::swap(A[k], A[i]);
          std::swap(B[k], B[i]);
          break;
        }
      }
      if (A[i][i] == T(0))
        return Matrix(0, 0);
      const T inv_pivot = T(1) / A[i][i];
      for (int j = 0; j < N; j++)
        A[i][j] *= inv_pivot, B[i][j] *= inv_pivot;
      for (int k = 0; k < N; k++)
      {
        if (k == i)
          continue;
        const T factor = A[k][i];
        for (int j = 0; j < N; j++)
        {
          A[k][j] -= A[i][j] * factor;
          B[k][j] -= B[i][j] * factor;
        }
      }
    }
    // Postcondition (disabled): assert((*this) * B == Matrix(N, N, T(1)));
    return B;
  }

  // Reduced row echelon form (every pivot is 1 and is the only nonzero
  // entry in its column).
  Matrix row_reduction() const
  {
    Matrix A(*this);
    const int N = A.rows(), M = A.cols();
    for (int i = 0, j = 0; i < N && j < M; j++)
    {
      for (int k = i; k < N; k++)
      {
        if (A[k][j] != T(0))
        {
          std::swap(A[k], A[i]);
          break;
        }
      }
      if (A[i][j] == T(0))
        continue;
      const T inv_pivot = T(1) / A[i][j];
      for (int l = 0; l < M; l++)
        A[i][l] *= inv_pivot;
      for (int k = 0; k < N; k++)
      {
        if (k == i)
          continue;
        const T factor = A[k][j];
        for (int l = 0; l < M; l++)
          A[k][l] -= A[i][l] * factor;
      }
      i++;
    }
    return A;
  }

  // Solves A x = b.
  // Returns an empty vector if there is no solution. Otherwise element 0 is
  // one particular solution and elements 1.. form a basis of the null space.
  std::vector<std::vector<T>> system(const std::vector<T> &b) const
  {
    Matrix A(*this);
    const int N = A.rows(), M = A.cols();
    assert((int)b.size() == N);
    for (int i = 0; i < N; i++)
      A[i].emplace_back(b[i]);
    A = A.row_reduction();

    // A column is a pivot column when it has exactly one nonzero entry and
    // that row has no pivot yet (in RREF this picks each row's leading 1).
    std::vector<int> pivot_i(M, -1), pivot_j(N, -1);
    std::vector<T> sol(M, T(0));
    for (int j = 0; j < M; j++)
    {
      int k = -1;
      for (int i = 0; i < N; i++)
      {
        if (A[i][j] != T(0))
        {
          if (k == -1)
            k = i;
          else
          {
            k = -2;
            break;
          }
        }
      }
      if (k >= 0 && pivot_j[k] == -1)
      {
        pivot_j[k] = j, pivot_i[j] = k;
        sol[j] = A[k][M];
      }
    }

    // Inconsistent systems fail this check (e.g. a pivot in the b column).
    if ((*this) * sol != b)
      return std::vector<std::vector<T>>();

    std::vector<std::vector<T>> basis;
    std::vector<T> base;
    for (int j = 0; j < M; j++)
    {
      if (pivot_i[j] != -1)
        continue;
      base.assign(M, T(0));
      base[j] = T(1);
      for (int i = 0; i < N; i++)
      {
        if (pivot_j[i] != -1)
          base[pivot_j[i]] = -A[i][j];
      }
      basis.emplace_back(base);
    }

    basis.insert(basis.begin(), sol);
    return basis;
  }
};

// https://yamate11.github.io/blog/posts/2021/04-18-kirchhoff/
// 自己ループのない無向グラフの全域木の個数を求める
// G[i][j] は (i, j) 間を結ぶ辺の本数。G[i][j] = G[j][i] を満たす必要がある
// G[i][i] = 0 のはずだが、そうでない入力に対しても G[i][i] = 0 とみなして計算する
template <class T, class U>
T count_spanning_trees(const vector<vector<U>> &G)
{
  using mat = Matrix<T>;

  int n = G.size();
  for (auto &a : G)
    assert((int)a.size() == n);
  if (n == 1)
    return 1;

  mat A(n - 1, n - 1, 0, 0);
  for (int i = 0; i < n - 1; i++)
  {
    for (int j = i + 1; j < n - 1; j++)
    {
      assert(G[i][j] == G[j][i]);
      A[i][j] = -G[i][j];
      A[j][i] = -G[i][j];
      A[i][i] += G[i][j];
      A[j][j] += G[i][j];
    }
    assert(G[i][n - 1] == G[n - 1][i]);
    A[i][i] += G[i][n - 1];
  }
  return A.det();
}

// Processes the graph with N vertices and edge list UV (0-based endpoints;
// M is the number of edges) and returns {ans, num}:
//   * connected graph: ans = 0 and num is built from spanning-tree counts
//     (see the first branch);
//   * otherwise, when there are at least two components: ans = number of
//     ordered vertex pairs in different components minus 2 * C[0] * C[1]
//     (C = component sizes, descending), and num = product of the components'
//     spanning-tree counts times the count formed at the end of the function.
// num is computed modulo 998244353 (mint).
// NOTE(review): the judge results at the top of this page (WA * 13, TLE * 3)
// show this function is not fully correct; the comments below describe what
// the code does, not a verified algorithm.
// NOTE(review): UV is taken by non-const reference but is never modified.
pair<ll, mint> solve(ll N, ll M, vector<pll> &UV)
{
  // Adjacency lists (each edge stored in both directions) and connectivity.
  vector<vector<ll>> G(N);
  dsu ds(N);
  for (auto [u, v] : UV)
  {
    G.at(u).push_back(v);
    G.at(v).push_back(u);
    ds.merge(u, v);
  }

  mint num = 1;
  if (ds.groups().size() == 1)
  {
    // H[i][j] = number of edges joining i and j (parallel edges counted).
    vector<vector<mint>> H(N, vector<mint>(N, 0));
    for (ll i = 0; i < N; i++)
    {
      for (auto j : G.at(i))
        H.at(i).at(j)++;
    }
    // Start from the number of spanning trees of the whole graph.
    num = count_spanning_trees<mint>(H);

    // For every edge (u, v), greedily rebuild an edge list UV2 from UV,
    // skipping an edge (u2, v2) whenever ds2 already connects u with u2 and
    // v with v2 (or u with v2 and v with u2); every copy of (u, v) itself,
    // in either orientation, is therefore skipped. The assert requires the
    // kept edges to leave exactly two components, so the recursive call
    // below takes the disconnected branch of solve().
    // NOTE(review): the intended meaning of UV2 and of the correction term
    // (M - UV2.size()) is not stated anywhere; confirm against the editorial.
    // NOTE(review): this is M recursive solves, each computing determinants
    // and an O(N^2) pair count -- a plausible cause of the TLE verdicts.
    for (auto [u, v] : UV)
    {
      vector<pll> UV2;
      dsu ds2(N);
      for (auto [u2, v2] : UV)
      {
        if (ds2.same(u, u2) && ds2.same(v, v2))
          continue;
        if (ds2.same(u, v2) && ds2.same(v, u2))
          continue;
        ds2.merge(u2, v2);
        UV2.push_back({u2, v2});
      }
      assert(ds2.groups().size() == 2);
      // Disabled debug output of (u, v) and the rebuilt edge list.
      /*
      cout << endl;
      cout << u << " " << v << endl;
      for (auto [u2, v2] : UV2)
        cout << u2 << " " << v2 << endl;
      //*/
      mint tmp = solve(N, UV2.size(), UV2).second - (M - UV2.size());
      num += tmp;
    }
    return {0, num};
  }

  // Disconnected graph: num = product over components of their
  // spanning-tree counts. Each component is re-indexed to 0..K-1 via `inv`
  // (global vertex -> local index) before building its edge-count matrix.
  // NOTE(review): ds.groups() was already computed once above, and
  // `for (auto g : gs)` copies every group; `const auto &g` would avoid it.
  auto gs = ds.groups();
  for (auto g : gs)
  {
    ll K = g.size();
    vector<vector<mint>> H(K, vector<mint>(K, 0));
    vector<ll> inv(N);
    for (ll i = 0; i < K; i++)
    {
      inv.at(g.at(i)) = i;
    }
    for (ll i = 0; i < K; i++)
    {
      ll u = g.at(i);
      for (auto v : G.at(u))
      {
        ll j = inv.at(v);
        H.at(i).at(j)++;
      }
    }
    mint tmp = count_spanning_trees<mint>(H);
    num *= tmp;
  }

  // Component sizes in descending order.
  vector<ll> C;
  for (auto g : gs)
    C.push_back(g.size());
  sort(C.begin(), C.end(), greater<ll>());
  // Ordered pairs (i, j) lying in different components. This is O(N^2) dsu
  // queries; the same value equals N * N - (sum of squared component sizes).
  ll ans = 0;
  for (ll i = 0; i < N; i++)
  {
    for (ll j = 0; j < N; j++)
    {
      if (!ds.same(i, j))
        ans++;
    }
  }
  if (C.size() >= 2)
  {
    // Connecting the two largest components makes 2 * C[0] * C[1] more
    // ordered pairs reachable.
    ans -= C.at(0) * C.at(1);
    ans -= C.at(1) * C.at(0);

    // Drop trailing components smaller than C[1] (C is sorted descending);
    // L components of size >= C[1] remain.
    while (C.size() > 2)
    {
      if (C.back() < C.at(1))
        C.pop_back();
      else
        break;
    }
    ll L = C.size();
    // If C[0] == C[1], all L remaining components have that size: choose two
    // of them (L(L-1)/2 ways) and one endpoint in each (C[0] * C[1]).
    // Otherwise pair the strictly largest component with one of the L - 1
    // components of size C[1], again times C[0] * C[1] endpoint choices.
    if (C.at(0) == C.at(1))
      num *= mint(L * (L - 1) / 2) * C.at(0) * C.at(1);
    else
      num *= mint(L - 1) * C.at(0) * C.at(1);
  }
  return {ans, num};
}

int main()
{
  ll N, M;
  cin >> N >> M;
  vector<pll> UV(M);
  for (ll i = 0; i < M; i++)
  {
    ll u, v;
    cin >> u >> v;
    u--, v--;
    UV.at(i) = {u, v};
  }
  auto [ans, num] = solve(N, M, UV);
  cout << ans << endl
       << num.val() << endl;
}
0