結果

問題 No.584 赤、緑、青の色塗り
ユーザー tubo28tubo28
提出日時 2017-10-26 21:35:22
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 697 ms / 2,000 ms
コード長 6,253 bytes
コンパイル時間 817 ms
コンパイル使用メモリ 76,964 KB
実行使用メモリ 4,384 KB
最終ジャッジ日時 2023-08-14 02:47:46
合計ジャッジ時間 2,889 ms
ジャッジサーバーID
(参考情報)
judge12 / judge13
このコードへのチャレンジ(β)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
4,380 KB
testcase_01 AC 1 ms
4,376 KB
testcase_02 AC 2 ms
4,380 KB
testcase_03 AC 1 ms
4,376 KB
testcase_04 AC 1 ms
4,376 KB
testcase_05 AC 2 ms
4,376 KB
testcase_06 AC 2 ms
4,380 KB
testcase_07 AC 2 ms
4,384 KB
testcase_08 AC 2 ms
4,380 KB
testcase_09 AC 2 ms
4,376 KB
testcase_10 AC 2 ms
4,376 KB
testcase_11 AC 2 ms
4,376 KB
testcase_12 AC 2 ms
4,376 KB
testcase_13 AC 2 ms
4,380 KB
testcase_14 AC 11 ms
4,376 KB
testcase_15 AC 697 ms
4,376 KB
testcase_16 AC 5 ms
4,376 KB
testcase_17 AC 3 ms
4,380 KB
testcase_18 AC 3 ms
4,376 KB
testcase_19 AC 473 ms
4,376 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <algorithm>
#include <iostream>
#include <cassert>
#include <vector>

template <int M, size_t F>
struct mod_int {
    constexpr static signed MODULO = M;
    constexpr static size_t MAX_FACT = F;

    signed x;

    mod_int() : x(0) {}

    mod_int(long long y) : x(static_cast<int>(y >= 0 ? y % MODULO : MODULO - (-y) % MODULO)) { }

    mod_int(int y) : x(y >= 0 ? y % MODULO : MODULO - (-y) % MODULO) { }

    mod_int &operator+=(const mod_int &rhs) {
        if ((x += rhs.x) >= MODULO) x -= MODULO;
        return *this;
    }

    mod_int &operator-=(const mod_int &rhs) {
        if ((x += MODULO - rhs.x) >= MODULO) x -= MODULO;
        return *this;
    }

    mod_int &operator*=(const mod_int &rhs) {
        x = static_cast<int>(1LL * x * rhs.x % MODULO);
        return *this;
    }

    mod_int &operator/=(const mod_int &rhs) {
        assert(rhs.x != 0);
        x = static_cast<int>((1LL * x * rhs.inv().x) % MODULO);
        return *this;
    }

    mod_int operator-() const { return mod_int(-x); }

    mod_int operator+(const mod_int &rhs) const { return mod_int(*this) += rhs; }

    mod_int operator-(const mod_int &rhs) const { return mod_int(*this) -= rhs; }

    mod_int operator*(const mod_int &rhs) const { return mod_int(*this) *= rhs; }

    mod_int operator/(const mod_int &rhs) const { return mod_int(*this) /= rhs; }

    bool operator<(const mod_int &rhs) const { return x < rhs.x; }

    mod_int inv() const {
        signed a = x, b = MODULO, u = 1, v = 0, t;
        while (b) {
            t = a / b;
            a -= t * b;
            std::swap(a, b);
            u -= t * v;
            std::swap(u, v);
        }
        return mod_int(u);
    }

    mod_int pow(long long t) const {
        assert(!(x == 0 && t == 0));
        mod_int e = *this, res = mod_int(1);
        for (; t; e *= e, t >>= 1)
            if (t & 1) res *= e;
        return res;
    }

    mod_int fact() {
        if (_fact[0].x == 0) prepare();
        return _fact[x];
    }

    mod_int choose(mod_int y) {
        assert(y.x <= x);
        return this->fact() / y.fact() / mod_int(x - y.x).fact();
    }

    static mod_int _fact[MAX_FACT + 1];

    static mod_int _inv_fact[MAX_FACT + 1];

    static void prepare() {
        _fact[0] = 1;
        for (int i = 1; i <= MAX_FACT; ++i) {
            _fact[i] = _fact[i - 1] * i;
        }
        _inv_fact[MAX_FACT] = _fact[MAX_FACT].inv();
        for (int i = (int) MAX_FACT - 1; i >= 0; --i) {
            _inv_fact[i] = _inv_fact[i + 1] * (i + 1);
        }
    }
};

template<int M, size_t F>
std::ostream &operator<<(std::ostream &os, const mod_int<M, F> &rhs) {
    return os << rhs.x;
}

template<int M, size_t F>
std::istream &operator>>(std::istream &is, mod_int<M, F> &rhs) {
    long long s;
    is >> s;
    rhs = mod_int<M, F>(s);
    return is;
};

template<int M, size_t F>
mod_int<M, F> mod_int<M, F>::_fact[MAX_FACT + 1];

template<int M, size_t F>
mod_int<M, F> mod_int<M, F>::_inv_fact[MAX_FACT + 1];

template<int M, size_t F>
bool operator == (const mod_int<M, F> &lhs, const mod_int<M, F> &rhs) {
    return lhs.x == rhs.x;
}

template<int M, size_t F>
bool operator != (const mod_int<M, F> &lhs, const mod_int<M, F> &rhs) {
    return !(lhs == rhs);
}

const int MF = 3010;
const int MOD = 1000000007;

using mint = mod_int<MOD, MF>;

mint binom(int n, int r) {
    return (r < 0 || r > n || n < 0) ? 0 : mint(n).choose(r);
}

mint fact(int n) {
    return mint(n).fact();
}

using namespace std;

// K = R + G + B
// K = 1 + 2 + 2 + 2 + 1 + ...
// f(i) := 2がi個,1がk-2i個あるのを並べる
// g(i) := R,G,Bから色が違うペアをi個
// ans += f(i)G(i)

using ll = long long;
ll N, R, G, B;
ll K;

ll naive() {
    vector<int> a;
    if (N < K) return 0;
    for (int i = 0; i < N - K; ++i) {
        a.push_back(0);
    }
    for (int i = 0; i < R; ++i) {
        a.push_back(1);
    }
    for (int j = 0; j < G; ++j) {
        a.push_back(2);
    }
    for (int k = 0; k < B; ++k) {
        a.push_back(3);
    }

    int ans = 0;
    do {
        bool ok = true;
        for (int i = 0; i < N - 1; ++i) {
            if (a[i] && a[i + 1] && a[i] == a[i + 1]) {
                ok = false;
                break;
            }
            if (i + 2 != N && a[i] && a[i + 1] && a[i + 2]) {
                ok = false;
                break;
            }
        }
        if (ok) {
            ++ans;
        }
    } while (next_permutation(a.begin(), a.end()));
    return ans;
}


mint pow2[3010];

mint f(ll single, ll pair) {
    mint res = 1;
    ll space = N - (pair * 2 + single);
    if (space < 0) return 0;
    res *= binom(single + pair, single);
    ll p = single + pair;
    ll q = N - (2 * pair + single);
    res *= binom(q + 1, p);
    return res;
}

mint g(ll single, ll pair) {
    mint res = 0;
    for (ll r = pair - min(G, B); r <= min({pair, R, G + B}); ++r) {
        int single_r = R - r;
        int single_g = G - (pair - r);
        int single_b = B - (pair - r);
        mint x = 1;
        x *= binom(pair, r);
        x *= binom(single, single_r);
        x *= binom(single - single_r + r, single_g);
        x *= binom(single - single_r + r - single_g, single_b);
        res += x;
    }
    return res;
}

int solve(int n, int r, int g, int b) {
    N = n;
    R = r;
    G = g;
    B = b;
    K = R + G + B;
    mint ans = 0;
    for (int pair = 0; pair * 2 <= K; ++pair) {
        ll single = R + G + B - pair * 2;
        ans += ::f(single, pair) * ::g(single, pair) * pow2[pair];
    }
    return ans.x;
}

int main() {
    pow2[0] = 1;
    for (int i = 1; i < 3010; ++i) {
        pow2[i] = pow2[i - 1] * 2;
    }

    if (false) {
        for (int n = 0; n < 30; ++n) {
            for (int r = 0; r <= 3; ++r) {
                for (int g = 0; g <= 3; ++g) {
                    for (int b = 0; b <= 3; ++b) {
                        int res = solve(n, r, g, b);
                        int exp = naive();
                        cout << res << ' ' << exp << endl;
                        assert(res == exp);
                    }
                }
            }
        }
    } else {
        int n, r, g, b;
        cin >> n >> r >> g >> b;
        cout << solve(n, r, g, b) << endl;
    }
}
0