結果

問題 No.1062 素敵なスコア
ユーザー okuraofvegetabl
提出日時 2020-05-09 11:46:26
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 218 ms / 2,000 ms
コード長 6,848 bytes
コンパイル時間 1,916 ms
コンパイル使用メモリ 181,612 KB
実行使用メモリ 17,092 KB
最終ジャッジ日時 2024-10-02 09:09:06
合計ジャッジ時間 5,380 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 9 ms
7,936 KB
testcase_01 AC 8 ms
7,936 KB
testcase_02 AC 9 ms
7,936 KB
testcase_03 AC 9 ms
7,808 KB
testcase_04 AC 9 ms
7,808 KB
testcase_05 AC 10 ms
7,936 KB
testcase_06 AC 9 ms
7,808 KB
testcase_07 AC 9 ms
7,936 KB
testcase_08 AC 10 ms
7,936 KB
testcase_09 AC 9 ms
7,936 KB
testcase_10 AC 10 ms
7,808 KB
testcase_11 AC 10 ms
7,808 KB
testcase_12 AC 9 ms
7,936 KB
testcase_13 AC 12 ms
8,064 KB
testcase_14 AC 10 ms
7,936 KB
testcase_15 AC 11 ms
7,936 KB
testcase_16 AC 9 ms
7,808 KB
testcase_17 AC 12 ms
8,064 KB
testcase_18 AC 110 ms
12,236 KB
testcase_19 AC 211 ms
16,600 KB
testcase_20 AC 109 ms
12,360 KB
testcase_21 AC 109 ms
12,508 KB
testcase_22 AC 211 ms
16,672 KB
testcase_23 AC 110 ms
12,472 KB
testcase_24 AC 59 ms
10,136 KB
testcase_25 AC 111 ms
12,248 KB
testcase_26 AC 110 ms
12,508 KB
testcase_27 AC 58 ms
9,916 KB
testcase_28 AC 218 ms
16,580 KB
testcase_29 AC 59 ms
10,268 KB
testcase_30 AC 213 ms
16,648 KB
testcase_31 AC 216 ms
16,956 KB
testcase_32 AC 10 ms
7,808 KB
testcase_33 AC 214 ms
17,092 KB
testcase_34 AC 216 ms
16,872 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

// #pragma GCC optimize("unroll-loops", "omit-frame-pointer", "inline")
// #pragma GCC option("arch=native", "tune=native", "no-zero-upper")
// #pragma GCC
// target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,avx2,tune=native")
// #pragma GCC optimize("Ofast")
// #pragma GCC optimize("tree-vectorize","openmp","predictive-commoning")
// #pragma GCC option("D_GLIBCXX_PARALLEL","openmp")
#include <bits/stdc++.h>

using namespace std;

// Short type aliases used throughout the solution.
using ll = long long;
using P = std::pair<int, int>;
using vi = std::vector<int>;
using vll = std::vector<ll>;
// #define int long long
// Terse spellings for common container / pair members.
#define pb push_back
#define mp make_pair
#define fi first
#define sec second
// Numeric constants.
#define eps 1e-9
#define INF 2000000000               // 2e9
#define LLINF 2000000000000000000ll  // 2e18 (llmax:9e18)
// Whole-range argument pair and squaring helper.
#define all(x) (x).begin(), (x).end()
#define sq(x) ((x) * (x))
// Prints "expr: value" to stderr.
#define dmp(x) cerr << #x << ": " << x << endl;

// Lowers `a` to `b` when `a > b`; otherwise leaves `a` untouched.
template <class T>
void chmin(T &a, const T &b) {
  if (!(a > b)) return;
  a = b;
}
// Raises `a` to `b` when `a < b`; otherwise leaves `a` untouched.
template <class T>
void chmax(T &a, const T &b) {
  if (!(a < b)) return;
  a = b;
}

// Heap whose top() is the largest element.
template <class T>
using MaxHeap = priority_queue<T>;
// Heap whose top() is the smallest element.
template <class T>
using MinHeap = priority_queue<T, vector<T>, greater<T>>;
// Returns a vector holding `len` copies of `elem`.
template <class T>
vector<T> vect(int len, T elem) {
  vector<T> filled(len, elem);
  return filled;
}

// Writes a pair as "first,second" with no surrounding spaces.
template <class T, class U>
ostream &operator<<(ostream &os, const pair<T, U> &p) {
  os << p.first;
  os << ',';
  os << p.second;
  return os;
}
// Reads a pair as two whitespace-separated values: first, then second.
template <class T, class U>
istream &operator>>(istream &is, pair<T, U> &p) {
  is >> p.first;
  is >> p.second;
  return is;
}
// Writes the elements separated by single spaces (no trailing space, no
// newline). Any stream width set by the caller applies to the first element.
template <class T>
ostream &operator<<(ostream &os, const vector<T> &vec) {
  bool leading = true;
  for (const auto &item : vec) {
    if (!leading) os << ' ';
    os << item;
    leading = false;
  }
  return os;
}
// Fills every existing slot of `vec` from the stream; the size is not changed.
template <class T>
istream &operator>>(istream &is, vector<T> &vec) {
  for (auto &item : vec) is >> item;
  return is;
}
// Speeds up stdio: stops syncing with C stdio, unties cin from cout, and
// makes cout print floating-point values in fixed notation with 20 digits
// after the decimal point.
void fastio() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  cout.setf(ios::fixed, ios::floatfield);
  cout.precision(20);
}

// Residue class modulo the compile-time constant MOD (a positive int);
// `val` is kept normalized to [0, MOD). Products go through 64-bit `ll`, so
// two operands below MOD < 2^31 never overflow.
// inv(), operator/ and negative exponents rely on Fermat's little theorem and
// are therefore only correct when MOD is prime.
template <int MOD>  // if inv is needed, this should be prime.
struct ModInt {
  ll val;
  ModInt() : val(0ll) {}
  // Implicit on purpose so plain integers mix with ModInt in expressions;
  // negative inputs are reduced into [0, MOD).
  ModInt(const ll &v) : val(((v % MOD) + MOD) % MOD) {}
  bool operator==(const ModInt &x) const { return val == x.val; }
  bool operator!=(const ModInt &x) const { return !(*this == x); }
  // Orderings compare the normalized representatives in [0, MOD).
  bool operator<(const ModInt &x) const { return val < x.val; }
  bool operator>(const ModInt &x) const { return val > x.val; }
  bool operator>=(const ModInt &x) const { return !(*this < x); }
  bool operator<=(const ModInt &x) const { return !(*this > x); }
  ModInt operator-() const { return ModInt(MOD - val); }
  // Multiplicative inverse as val^(MOD-2); requires prime MOD and val != 0.
  ModInt inv() const { return this->pow(MOD - 2); }
  ModInt &operator+=(const ModInt &x) {
    if ((val += x.val) >= MOD) val -= MOD;
    return *this;
  }
  ModInt &operator-=(const ModInt &x) {
    // Adding MOD - x.val keeps the intermediate value non-negative.
    if ((val += MOD - x.val) >= MOD) val -= MOD;
    return *this;
  }
  ModInt &operator*=(const ModInt &x) {
    (val *= x.val) %= MOD;
    return *this;
  }
  ModInt &operator/=(const ModInt &x) { return *this *= x.inv(); }
  ModInt operator+(const ModInt &x) const { return ModInt(*this) += x; }
  ModInt operator-(const ModInt &x) const { return ModInt(*this) -= x; }
  ModInt operator*(const ModInt &x) const { return ModInt(*this) *= x; }
  ModInt operator/(const ModInt &x) const { return ModInt(*this) /= x; }
  friend istream &operator>>(istream &i, ModInt &x) {
    ll v;
    i >> v;
    x = v;
    return i;
  }
  friend ostream &operator<<(ostream &o, const ModInt &x) {
    o << x.val;
    return o;
  }
  // Returns *this raised to x by binary exponentiation. A negative x yields
  // inv()^|x| (prime MOD only). Previously `x >>= 1` on a negative value
  // converged to -1 on arithmetic-shift implementations and never reached 0,
  // so the loop spun forever.
  ModInt pow(ll x) const {
    auto b = *this;
    // Take the magnitude in unsigned arithmetic so x == LLONG_MIN is safe.
    unsigned long long e = static_cast<unsigned long long>(x);
    if (x < 0) {
      b = b.inv();
      e = 0ull - e;
    }
    auto res = ModInt(1ll);
    while (e) {
      if (e & 1ull) res *= b;
      e >>= 1;
      b *= b;
    }
    return res;
  }
};

// Free-function form of ModInt::pow: returns a^x by binary exponentiation.
// A negative x is treated as (a^-1)^|x|, which needs a prime MOD. The previous
// loop never terminated for x < 0, because an arithmetic `x >>= 1` on a
// negative value converges to -1 instead of 0.
template <int MOD>
ModInt<MOD> pow(ModInt<MOD> a, ll x) {
  // Take the magnitude in unsigned arithmetic so x == LLONG_MIN is safe.
  unsigned long long e = static_cast<unsigned long long>(x);
  if (x < 0) {
    a = a.inv();
    e = 0ull - e;
  }
  ModInt<MOD> res = ModInt<MOD>(1ll);
  while (e) {
    if (e & 1ull) res *= a;
    e >>= 1;
    a *= a;
  }
  return res;
}

// Modulus used by `mint` throughout the solution; the alternative common
// prime is kept commented out for quick switching.
// constexpr int MOD = 1e9 + 7;
constexpr int MOD{998244353};
typedef ModInt<MOD> mint;

// Tables filled by init(SIZE) for indices 0..SIZE (mod MOD):
// inv[i] = 1/i (inv[0] is left as 0), fac[i] = i!, facinv[i] = 1/i!.
vector<mint> inv, fac, facinv;
// Binomial coefficient C(n, r) mod MOD, read from the factorial tables.
// Requires 0 <= r <= n and n within the tables built by init().
// notice: 0C0 = 1
ModInt<MOD> nCr(int n, int r) {
  assert(!(n < r));
  assert(!(n < 0 || r < 0));
  // Without this check an index past the tables (init() not called, or called
  // with too small a SIZE) is silent out-of-bounds access.
  assert(n < static_cast<int>(fac.size()));
  return fac[n] * facinv[r] * facinv[n - r];
}

// Fills fac, inv and facinv for indices 0..SIZE (mod MOD).
// inv[i] is built with the linear recurrence inv[i] = -(MOD / i) * inv[MOD % i].
// Every index MOD % i it reads is below i, so it is already computed. The
// recurrence is valid only for 2 <= i < MOD with MOD prime.
void init(int SIZE) {
  assert(0 <= SIZE && SIZE < MOD);
  fac.resize(SIZE + 1);
  inv.resize(SIZE + 1);
  facinv.resize(SIZE + 1);
  fac[0] = facinv[0] = mint(1ll);
  // inv has SIZE + 1 slots, so inv[1] exists only when SIZE >= 1; writing it
  // unconditionally was out of bounds for init(0).
  if (SIZE >= 1) inv[1] = mint(1ll);
  for (int i = 1; i <= SIZE; i++) fac[i] = fac[i - 1] * mint(i);
  for (int i = 2; i <= SIZE; i++)
    inv[i] = mint(0ll) - mint(MOD / i) * inv[MOD % i];
  for (int i = 1; i <= SIZE; i++) facinv[i] = facinv[i - 1] * inv[i];
  return;
}

// Number-theoretic transform over Z/MOD, assuming `primitive` generates the
// multiplicative group mod MOD. Transform lengths must be powers of two that
// divide MOD - 1. dft/idft outputs lie in [0, MOD) when their inputs do;
// mult normalizes its inputs itself.
template <ll MOD, ll primitive>
class NTT {
 public:
  // x^a mod `mod` by binary exponentiation; x is expected in [0, mod).
  static ll power_mod(ll x, ll a, ll mod) {
    ll res = 1ll;
    while (a > 0ll) {
      if (a & 1) res = (res * x) % mod;
      x = (x * x) % mod;
      a >>= 1;
    }
    return res;
  }
  static ll get_MOD() { return MOD; }
  // Recursive radix-2 transform of the first n entries of f. Requires n to be
  // a power of two and f.size() >= n. sgn == -1 uses the inverse root, which
  // gives the unscaled inverse transform; idft applies the 1/n factor.
  static vector<ll> dft(vector<ll> f, int n, int sgn = 1) {
    if (n == 1) return f;
    const int half = n / 2;
    // Split into even- and odd-indexed coefficients.
    vector<ll> f0(half), f1(half);
    for (int i = 0; i < half; i++) {
      f0[i] = f[i * 2];
      f1[i] = f[i * 2 + 1];
    }
    f0 = dft(std::move(f0), half, sgn);
    f1 = dft(std::move(f1), half, sgn);
    // n-th root of unity (inverted for sgn == -1), given `primitive` is a
    // generator.
    ll zeta = power_mod(primitive, (MOD - 1ll) / (ll)n, MOD);
    if (sgn == -1) zeta = power_mod(zeta, MOD - 2, MOD);
    ll pow_zeta = 1ll;
    for (int i = 0; i < n; i++) {
      f[i] = (f0[i % half] + pow_zeta * f1[i % half]) % MOD;
      pow_zeta = (pow_zeta * zeta) % MOD;
    }
    return f;
  }
  // Inverse transform: unscaled inverse dft followed by scaling with n^-1.
  static vector<ll> idft(vector<ll> f, int n) {
    f = dft(std::move(f), n, -1);
    const ll ninv = power_mod(n, MOD - 2, MOD);
    for (auto &x : f) x = (x * ninv) % MOD;
    return f;
  }
  // Linear (non-cyclic) convolution of A and B mod MOD. The result has the
  // power-of-two length n > |A| + |B|, and entries from index |A| + |B| - 1
  // on are zero. This padded length is kept because callers may index past
  // |A| + |B| - 2.
  static vector<ll> mult(vector<ll> A, vector<ll> B) {
    // The transforms multiply raw entries, so reduce them into [0, MOD)
    // first. Negative or oversized coefficients would otherwise give
    // negative outputs or overflow signed 64-bit arithmetic (UB).
    for (auto &x : A) x = ((x % MOD) + MOD) % MOD;
    for (auto &x : B) x = ((x % MOD) + MOD) % MOD;
    int n = 1;
    while (static_cast<size_t>(n) < A.size() + B.size() + 1) n <<= 1;
    A.resize(n, 0);
    B.resize(n, 0);
    A = dft(std::move(A), n);
    B = dft(std::move(B), n);
    vector<ll> C(n);
    for (int i = 0; i < n; i++) C[i] = (A[i] * B[i]) % MOD;
    return idft(std::move(C), n);
  }
};

// NTT over the same prime as `mint` (998244353), with 3 as the primitive root.
typedef NTT<998244353ll, 3ll> ntt;

// Make `endl` expand to a plain newline so output is not flushed per line.
#define endl "\n"

void solve() {
  int N, A1, A2;
  cin >> N >> A1 >> A2;
  if (A1 > A2) swap(A1, A2);
  int A = A1;
  int B = A2 - A;
  int C = N - A2;
  mint ans;
  for (int i = 0; i < A; i++) { ans += mint(A) * fac[N - 1]; }
  for (int i = 0; i < B; i++) { ans += mint(B) * fac[N - 1]; }
  for (int i = 0; i < C; i++) { ans += mint(C) * fac[N - 1]; }
  vector<ll> f(A), g(C);
  for (int i = 0; i < A; i++) f[i] = nCr(A - 1, i).val;
  for (int i = 0; i < C; i++) g[i] = nCr(C - 1, i).val;
  auto _h = ntt::mult(f, g);
  vector<mint> h(A + C);
  for (int i = 0; i < h.size(); i++) {
    if (N - 2 - i < 0) continue;
    h[i] = mint(_h[i]);
    h[i] *= fac[i] * fac[N - 2 - i];
    ans += h[i] * mint(2ll) * mint(A) * mint(C);
  }
  cout << ans << endl;
  return;
}

// Entry point. Configures fast stdio, builds the factorial tables up to
// index 200100, then handles the single test case in the input.
signed main() {
  fastio();
  init(200100);
  solve();
  return 0;
}
0