結果

問題 No.108 トリプルカードコンプ
ユーザー mkreemmkreem
提出日時 2024-03-18 22:08:06
言語 C++23
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 14 ms / 5,000 ms
コード長 13,705 bytes
コンパイル時間 3,673 ms
コンパイル使用メモリ 266,404 KB
実行使用メモリ 14,080 KB
最終ジャッジ日時 2024-03-18 22:08:11
合計ジャッジ時間 4,671 ms
ジャッジサーバーID
(参考情報)
judge15 / judge12
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 10 ms
14,080 KB
testcase_01 AC 9 ms
14,080 KB
testcase_02 AC 9 ms
14,080 KB
testcase_03 AC 8 ms
14,080 KB
testcase_04 AC 9 ms
14,080 KB
testcase_05 AC 9 ms
14,080 KB
testcase_06 AC 9 ms
14,080 KB
testcase_07 AC 14 ms
14,080 KB
testcase_08 AC 8 ms
14,080 KB
testcase_09 AC 9 ms
14,080 KB
testcase_10 AC 9 ms
14,080 KB
testcase_11 AC 9 ms
14,080 KB
testcase_12 AC 9 ms
14,080 KB
testcase_13 AC 10 ms
14,080 KB
testcase_14 AC 9 ms
14,080 KB
testcase_15 AC 9 ms
14,080 KB
testcase_16 AC 11 ms
14,080 KB
testcase_17 AC 13 ms
14,080 KB
testcase_18 AC 12 ms
14,080 KB
testcase_19 AC 10 ms
14,080 KB
testcase_20 AC 9 ms
14,080 KB
testcase_21 AC 11 ms
14,080 KB
testcase_22 AC 8 ms
14,080 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#ifndef INCLUDED_MAIN
#define INCLUDED_MAIN

#include __FILE__

const ll mod = 998244353; using mint = atcoder::modint998244353;
//const ll mod = 1000000007; using mint = atcoder::modint1000000007;

//------------------------------------------------------------------------------------------------------------------

//------------------------------------------------------------------------------------------------------------------



int main(){
  fast();

  int n; cin >> n;
  int cnt0 = 0, cnt1 = 0, cnt2 = 0;
  rep(i, 0, n){
    int a; cin >> a;
    if(a == 0) cnt0++;
    if(a == 1) cnt1++;
    if(a == 2) cnt2++;
  }

  // dp[i][j][k] := 0, 1, 2枚持っているカードがそれぞれi, j, k種類あるとき、引くカード枚数の期待値
  auto dp = Mvector<3, double>(-1, 109, 109, 109);
  dp[0][0][0] = 0; // 自明ケース
  auto f = [&](auto f, int i, int j, int k) -> double{
    if(dp[i][j][k] != -1) return dp[i][j][k];

    double res = 0.0;
    res += (double)n/(i+j+k);
    if(i) res += (double)i/(i+j+k) * f(f, i-1, j+1, k);
    if(j) res += (double)j/(i+j+k) * f(f, i, j-1, k+1);
    if(k) res += (double)k/(i+j+k) * f(f, i, j, k-1);
    return dp[i][j][k] = res;
  };

  PREC;

  cout << f(f, cnt0, cnt1, cnt2) << newl;

}

#else 

#include <bits/stdc++.h>
#include <atcoder/modint>
#include <atcoder/math>
#include <atcoder/lazysegtree.hpp>

void fast(){
  std::ios::sync_with_stdio(false);
  std::cin.tie(nullptr);
}

using ll = long long;
using ld = long double;
#define newl '\n'
#define INF 1000000039
#define LLINF 393939393939393939
#define IMAX INT_MAX
#define IMIN INT_MIN
#define LLMAX LONG_LONG_MAX
#define LLMIN LONG_LONG_MIN
#define PREC std::cout << setprecision(15)
#define PI acos(-1)
#define fore(i, a) for(auto &i : a)
#define rep(i, a, b) for(int i = (a); i < (b); i++)
#define erep(i, a, b) for(int i = (a); i <= (b); i++)
#define rrep(i, a, b) for(int i = (a); i >= (b); i--)
#define all(x) (x).begin(), (x).end()
#define rall(x) (x).rbegin(), (x).rend()
#define pcnt(x) __builtin_popcount(x)
#define llpcnt(x) __builtin_popcountll(x)
template <typename T>
int lwb(const std::vector<T>& vec, T x){
  return lower_bound(all(vec), x) - vec.begin();
}
template <typename T>
int upb(const std::vector<T>& vec, T x){
  return upper_bound(all(vec), x) - vec.begin();
}
template <typename T>
auto max(const T& x){ return *max_element(all(x)); }
template <typename T>
auto min(const T& x){ return *min_element(all(x)); }
template <typename T>
using pq = std::priority_queue<T>;
template <typename T>
using minpq = std::priority_queue<T, std::vector<T>, std::greater<T>>;
// 最大値・最小値の更新
template <typename T1, typename T2>
bool chmax(T1 &a, const T2 &b){
  if(a < b){ a = b; return 1; }
  else return 0;
}
template <typename T1, typename T2>
bool chmin(T1 &a, const T2 &b){
  if(a > b){ a = b; return 1; }
  else return 0;
}
template <typename T>
void iota(std::vector<T>& vec, bool greater = false){
  std::iota(all(vec), 0);
  std::sort(all(vec), [&](int i, int j){
    if(greater) return vec[i] > vec[j];
    return vec[i] < vec[j];
  });
}
// pairのsecondの昇順にソートする比較関数
template <typename T1, typename T2>
bool cmp(std::pair<T1, T2> a, std::pair<T1, T2> b){
  if(a.second != b.second) return a.second < b.second;
  else return a.first < b.first;
}
// 多次元配列の生成
template <size_t Dimention, typename T>
class Mvector : public std::vector<Mvector<Dimention-1, T>>{
public:
  template <typename N, typename... Sizes>
  Mvector(T init, N n, Sizes... sizes) : std::vector<Mvector<Dimention-1, T>>(n, Mvector<Dimention-1, T>(init, sizes...))
  { }
};
template <typename T>
class Mvector<1, T> : public std::vector<T>{
public:
  template <typename N>
  Mvector(T init, N n) : std::vector<T>(n, init)
  { }
};
// 2つのvectorをマージ
template <typename T>
std::vector<T> vmerge(std::vector<T>& a, std::vector<T>& b){
  std::vector<T> res;
  std::sort(a.begin(), a.end());
  std::sort(b.begin(), b.end());
  std::merge(a.begin(), a.end(), b.begin(), b.end(), std::back_inserter(res));
  return res;
}


// 辺
template <typename T>
class Edge{
public:
  int from, to;
  T cost;
  int ID;
  Edge(int to, T cost) : to(to), cost(cost) {} // for WG
  Edge(int from, int to, T cost) : from(from), to(to), cost(cost) {} // for Edges
  Edge(int from, int to, T cost, int ID) : from(from), to(to), cost(cost), ID(ID) {} // for Edges
  bool operator<(const Edge<T>& rhs) const { return cost < rhs.cost; };
  bool operator>=(const Edge<T>& rhs) const { return !(cost < rhs.cost); };
  bool operator>(const Edge<T>& rhs) const { return cost > rhs.cost; };
  bool operator<=(const Edge<T>& rhs) const { return !(cost > rhs.cost); };
  bool operator==(const Edge<T>& rhs) const { return cost == rhs.cost; };
  bool operator!=(const Edge<T>& rhs) const { return !(cost == rhs.cost); };
};
using G = std::vector<std::vector<int>>;
template <typename T>
using WG = std::vector<std::vector<Edge<T>>>;
template <typename T>
using Edges = std::vector<Edge<T>>;

template <typename T, typename F>
// @param ok 解が存在する値
// @param ng 解が存在しない値
// @remark ok > ng の場合は最小値、ok < ng の場合は最大値を返却
T Bsearch(T& ok, T& ng, const F& f){
  while(abs(ok-ng) > 1){
    T mid = (ok+ng)/2;
    (f(mid) ? ok : ng) = mid;
  }
  return ok;
}
template <typename T, typename F>
T Bsearch_double(T& ok, T& ng, const F& f, int itr = 80){
  while(itr--){
    T mid = (ok+ng)/2;
    //T mid = sqrtl(ok*ng);
    (f(mid) ? ok : ng) = mid;
  }
  return ok;
}


template <typename T>
// @brief (k, n-k)-shuffleである0, 1, ..., N-1 の置換Aを、辞書順で列挙する
bool next_shuffle(std::vector<T>& vec, int k){
  int n = vec.size();
  if(n <= k){
    return false;
  }

  // 前K項 := L
  // 後ろN-K項 := R
  auto left = vec.begin();
  auto right = vec.begin() + k;
  T R_max = *std::max_element(right, vec.end());
  T tmp = (std::numeric_limits<T>::min)();
  // @param i Lの要素の中で、Rの要素の最大値よりも小さいもののうち、最大のもののイテレータ(*i := L_(i))
  auto tmpi = left, i = right;
  while(tmpi != right){
    if(tmp <= *tmpi && *tmpi < R_max){
      tmp = *tmpi;
      i = tmpi;
    }
    tmpi++;
  }
  if(i == right){
    return false;
  }

  // @param j Rの要素の中で、L_(i)よりも大きいもののうち、最小のもののイテレータ(*j := R_(j))
  tmp = (std::numeric_limits<T>::max)();
  auto tmpj = right, j = vec.end();
  while(tmpj != vec.end()){
    if(tmp >= *tmpj && *tmpj > *i){
      tmp = *tmpj;
      j = tmpj;
    }
    tmpj++;
  }

  std::iter_swap(i, j); // L_(i)とR_(j)をswap
  i++, j++;
  // やりたいこと:L_(i+1)~L_(k-1)(:= X)とR_(j+1)~R_(n-k-1)(:= Y)を接続し、R_(j+1)が先頭に来るように回転する
  int X_len = k-std::distance(left, i);
  int Y_len = n-k-std::distance(right, j);
  int swap_len = std::min(X_len, Y_len);
  // Xの末尾swap_len項とYの末尾swap_len項をswapする
  std::swap_ranges(right-swap_len, right, j);
  if(swap_len == X_len){
    std::rotate(j, j+swap_len, vec.end());
  }
  else{
    std::rotate(i, right-swap_len, right);
  }

  return true;
}

int log2ll(long long N){
  int B = -1;
  while(N != 0){
    B++;
    N /= 2;
  }
  return B;
}
template <typename T>
// @brief (2,...,2)-shuffleである0, 1, ..., 2*N-1 の置換Aを、辞書順で列挙する
bool next_pairing(std::vector<T>& vec){
  int n = vec.size();
  // @param used vecに含まれるどの数が使用済みであるか
  ll used = 0;
  for(int i = n-1; i >= 0; i--){
    used |= (1<<vec[i]);
    if(i%2 == 1 && vec[i] < log2ll(used)){ // インクリメントできる
      vec[i] = __builtin_ctzll(used >> (vec[i]+1)) + vec[i]+1;
      used ^= (1<<vec[i]);
      for(int j = i+1; j < n; j++){
        vec[j] = __builtin_ctzll(used);
        used ^= (1<<vec[j]);
      }
      return true;
    }
  }
  return false;
}

// @brief 閉区間をsetで管理する
template <typename T>
class RangeSet{
private:
  std::set<std::pair<T, T>> s;
  T sum; // @param sum RangeSet内の要素数
  T TINF = std::numeric_limits<T>::max()/2;

public:
  RangeSet() : sum(T(0)){
    s.emplace(TINF, TINF);
    s.emplace(-TINF, -TINF);
  }

  // @brief 区間[l, r]が完全に含まれているかどうかを返す
  bool covered(const T l, const T r){
    assert(l <= r);
    auto itr = std::prev(s.lower_bound({l+1, l+1}));
    return ((itr->first <= l) && (r <= itr->second));
  }

  // @brief xが含まれているかどうかを返す
  bool contained(const T x){
    auto itr = std::prev(s.lower_bound({x+1, x+1}));
    return ((itr->first <= x) && (x <= itr->second));
  }

  // @brief 区間[l, r]を包含する区間があればその区間を返し、なければ[-INF, -INF]を返す
  std::pair<T, T> covered_by(const T l, const T r){
    assert(l <= r);
    auto itr = std::prev(s.lower_bound({l+1, l+1}));
    if(itr->first <= l && r <= itr->second) return *itr;
    return {-TINF, -TINF};
  }

  std::pair<T, T> covered_by(const T x){
    return covered_by(x, x);
  }

  // @brief 区間[l, r]を挿入し、増分を返す
  T insert(T l, T r){
    assert(l <= r);
    auto itr = std::prev(s.lower_bound({l+1, l+1}));

    if(itr->first <= l && r <= itr->second) return T(0); // [l, r]がすでに完全に含まれている
    T sum_erased = T(0); // @param sum_erased 消した区間の幅の合計
    if(itr->first <= l && l <= itr->second+1){ // l側で、区間itrをマージできる場合
      l = itr->first;
      sum_erased += itr->second - itr->first + 1;
      itr = s.erase(itr);
    }
    else{ // できなかったら、itrを次の区間に進める
      itr = std::next(itr);
    }
    while(r > itr->second){
      sum_erased += itr->second - itr->first + 1;
      itr = s.erase(itr);
    }
    if(itr->first <= r+1 && r <= itr->second){ // r側で、区間itrをマージできる場合
      sum_erased += itr->second - itr->first + 1;
      r = itr->second;
      s.erase(itr);
    }
    s.emplace(l, r);
    sum += r-l+1-sum_erased;
    return r-l+1-sum_erased;
  }

  T insert(const T x){
    return insert(x, x);
  }

  // @brief 区間[l, r]を削除し、減分を返す
  T erase(const T l, const T r){
    assert(l <= r);
    auto itr = std::prev(s.lower_bound({l+1, l+1}));
    if(itr->first <= l && r <= itr->second){ // [l, r]が、1つの区間に包含されている
      // はみ出した区間
      if(itr->first < l) s.emplace(itr->first, l-1);
      if(r < itr->second) s.emplace(r+1, itr->second);
      s.erase(itr);
      sum -= r-l+1;
      return r-l+1;
    }

    T res = T(0);
    if(itr->first <= l && l <= itr->second){ // l側で、区間itrを消せる場合
      res += itr->second-l+1;
      // はみ出した区間
      if(itr->first < l) s.emplace(itr->first, l-1);
      itr = s.erase(itr);
    }
    else{
      itr = std::next(itr);
    }
    while(itr->second <= r){
      res += itr->second - itr->first + 1;
      itr = s.erase(itr);
    }
    if(itr->first <= r && r <= itr->second){ // r側で、区間itrを消せる場合
      res += r-itr->first+1;
      // はみ出した区間
      if(r < itr->second) s.emplace(r+1, itr->second);
      s.erase(itr);
    }
    sum -= res;
    return res;
  }

  T erase(const T x){
    return erase(x, x);
  }

  // @brief 区間の数を返す
  int size() const{
    return (int)s.size()-2;
  }

  /*
  x以上で含まれてない最小の要素は
  ・xが含まれていない:x
  ・xが含まれている:xを含む区間の末端に1加えたもの
  */
  T mex(const T x = 0) const{
    auto itr = std::prev(s.lower_bound({x+1, x+1}));
    if(itr->first <= x && x <= itr->second) return itr->second+1;
    else return x;
  }

  // @brief RangeSet内の要素数を返す
  T sum_all() const{
    return sum;
  }

  // @brief 全区間を保持したsetを返す
  std::set<std::pair<T, T>> get() const{
    std::set<std::pair<T, T>> res;
    for(auto& interval : s) {
      if(std::abs(interval.first) == TINF) continue;
      res.emplace(interval.first, interval.second);
    }
    return res;
  }

  void output() const{
    std::cout << "RangeSet:";
    for(auto& interval : s){
      if(interval.first == -INF || interval.second == INF) continue;
      std::cout << "[" << interval.first << "," << interval.second << "]";
    }
    std::cout << '\n';
  }
};


const int di4[4] = {-1, 0, 1, 0};
const int dj4[4] = {0, 1, 0, -1};
const int di8[8] = {-1, -1, 0, 1, 1, 1, 0, -1};
const int dj8[8] = {0, 1, 1, 1, 0, -1, -1, -1};
const std::vector<std::tuple<int, int, int>> line3{{0,1,2}, {3,4,5}, {6,7,8}, {0,3,6}, {1,4,7}, {2,5,8}, {0,4,8}, {2,4,6}};
const std::vector<std::tuple<int, int, int, int>> line4{{0,1,2,3}, {4,5,6,7}, {8,9,10,11}, {12,13,14,15}, {0,4,8,12}, {1,5,9,13}, {2,6,10,14}, {3,7,11,15}, {0,5,10,15}, {3,6,9,12}};

bool OutOfGrid(int i, int j, int h, int w){
  if(i < 0 || j < 0 || i >= h || j >= w) return true;
  return false;
}

// @brief 繰り返し二乗法を利用した、x^nの求値
ll power(ll x, ll n){
  ll res = 1;

  while(n){
    if(n & 1) res *= x;
    x *= x;
    n >>= 1;
  }
  
  return res;
}

ll power_mod(ll x, ll n, ll m){
  ll res = 1;

  while(n){
    if(n & 1){
      res = (res*x) % m;
    }
    x = (x*x) % m;
    n >>= 1;
  }

  return res;
}

// @brief x/mのfloor(x/m以下の最大の整数)を求める
ll floor(ll x, ll m){
  ll r = (x%m + m) % m; // xをmで割った余り
  return (x-r)/m;
}

// @brief x/mのceil(x/m以上の最小の整数)を求める
ll ceil(ll x, ll m){
  return floor(x+m-1, m); // x/m + (m-1)/m
}



using namespace std;

#endif
0