結果

問題 No.391 CODING WAR
ユーザー xryuseixxryuseix
提出日時 2021-11-02 17:42:44
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 54 ms / 2,000 ms
コード長 7,342 bytes
コンパイル時間 5,365 ms
コンパイル使用メモリ 131,900 KB
実行使用メモリ 6,912 KB
最終ジャッジ日時 2024-04-20 03:36:53
合計ジャッジ時間 2,750 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 29 ms
6,656 KB
testcase_01 AC 29 ms
6,656 KB
testcase_02 AC 29 ms
6,656 KB
testcase_03 AC 29 ms
6,912 KB
testcase_04 AC 27 ms
6,656 KB
testcase_05 AC 28 ms
6,784 KB
testcase_06 AC 28 ms
6,784 KB
testcase_07 AC 28 ms
6,784 KB
testcase_08 AC 28 ms
6,656 KB
testcase_09 AC 54 ms
6,656 KB
testcase_10 AC 54 ms
6,656 KB
testcase_11 AC 46 ms
6,784 KB
testcase_12 AC 29 ms
6,784 KB
testcase_13 AC 51 ms
6,656 KB
testcase_14 AC 46 ms
6,784 KB
testcase_15 AC 49 ms
6,912 KB
testcase_16 AC 42 ms
6,784 KB
testcase_17 AC 45 ms
6,784 KB
testcase_18 AC 38 ms
6,784 KB
testcase_19 AC 39 ms
6,784 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <algorithm>
#include <bitset>
#include <cassert>
#include <cctype>
#include <cfloat>
#include <climits>
#include <cmath>
#include <cstdio>
#include <deque>
#include <functional>
#include <iomanip>
#include <iostream>
#include <list>
#include <map>
#include <queue>
#include <random>
#include <set>
#include <stack>
#include <string>
#include <unordered_set>
#include <vector>
#pragma GCC target("avx2")
#pragma GCC optimize("O3")
#pragma GCC optimize("unroll-loops")
using namespace std;
typedef long double ld;
typedef long long int ll;
typedef unsigned long long int ull;
typedef vector<int> vi;
typedef vector<char> vc;
typedef vector<bool> vb;
typedef vector<double> vd;
typedef vector<string> vs;
typedef vector<ll> vll;
typedef vector<pair<int, int>> vpii;
typedef vector<pair<ll, ll>> vpll;
typedef vector<vi> vvi;
typedef vector<vvi> vvvi;
typedef vector<vc> vvc;
typedef vector<vs> vvs;
typedef vector<vll> vvll;
typedef pair<int, int> P;
typedef map<int, int> mii;
typedef set<int> si;
#define rep(i, n) for (ll i = 0; i < (n); ++i)
#define rrep(i, n) for (int i = 1; i <= (n); ++i)
#define irep(it, stl) for (auto it = stl.begin(); it != stl.end(); it++)
#define drep(i, n) for (int i = (n)-1; i >= 0; --i)
#define fin(ans) cout << (ans) << '\n'
#define STLL(s) strtoll(s.c_str(), NULL, 10)
#define mp(p, q) make_pair(p, q)
#define pb(n) push_back(n)
#define all(a) a.begin(), a.end()
#define Sort(a) sort(a.begin(), a.end())
#define Rort(a) sort(a.rbegin(), a.rend())
#define fi first
#define se second
// #include <atcoder/all>
// using namespace atcoder;
constexpr int dx[8] = {1, 0, -1, 0, 1, -1, -1, 1};
constexpr int dy[8] = {0, 1, 0, -1, 1, 1, -1, -1};
template <class T, class U>
inline bool chmax(T& a, U b) {
    if (a < b) {
        a = b;
        return 1;
    }
    return 0;
}
template <class T, class U>
inline bool chmin(T& a, U b) {
    if (a > b) {
        a = b;
        return 1;
    }
    return 0;
}
template <class T, class U>
ostream& operator<<(ostream& os, pair<T, U>& p) {
    cout << "(" << p.first << ", " << p.second << ")";
    return os;
}
template <class T>
inline void dump(T& v) {
    irep(i, v) { cout << (*i) << ((i == --v.end()) ? '\n' : ' '); }
}
template <class T, class U>
inline void dump(map<T, U>& v) {
    if (v.size() > 100) {
        cout << "ARRAY IS TOO LARGE!!!" << endl;
    } else {
        irep(i, v) { cout << i->first << " " << i->second << '\n'; }
    }
}
template <class T, class U>
inline void dump(pair<T, U>& p) {
    cout << p.first << " " << p.second << '\n';
}
inline void yn(const bool b) { b ? fin("yes") : fin("no"); }
inline void Yn(const bool b) { b ? fin("Yes") : fin("No"); }
inline void YN(const bool b) { b ? fin("YES") : fin("NO"); }
void Case(int i) { printf("Case #%d: ", i); }
const int INF = INT_MAX;
constexpr ll LLINF = 1LL << 61;
constexpr ld EPS = 1e-11;
#define MODTYPE 0
#if MODTYPE == 0
constexpr ll MOD = 1000000007;
#else
constexpr ll MOD = 998244353;
#endif
/* --------------------   ここまでテンプレ   -------------------- */

struct mint {
    ll x;
    mint(ll _x = 0) : x((_x % MOD + MOD) % MOD) {}

    /* 初期化子 */
    mint operator+() const { return mint(x); }
    mint operator-() const { return mint(-x); }

    /* 代入演算子 */
    mint& operator+=(const mint a) {
        if ((x += a.x) >= MOD) x -= MOD;
        return *this;
    }
    mint& operator-=(const mint a) {
        if ((x += MOD - a.x) >= MOD) x -= MOD;
        return *this;
    }
    mint& operator*=(const mint a) {
        (x *= a.x) %= MOD;
        return *this;
    }
    mint& operator/=(const mint a) {
        x *= modinv(a).x;
        x %= MOD;
        return (*this);
    }
    mint& operator%=(const mint a) {
        x %= a.x;
        return (*this);
    }
    mint& operator++() {
        ++x;
        if (x >= MOD) x -= MOD;
        return *this;
    }
    mint& operator--() {
        if (!x) x += MOD;
        --x;
        return *this;
    }
    mint& operator&=(const mint a) {
        x &= a.x;
        return (*this);
    }
    mint& operator|=(const mint a) {
        x |= a.x;
        return (*this);
    }
    mint& operator^=(const mint a) {
        x ^= a.x;
        return (*this);
    }
    mint& operator<<=(const mint a) {
        x *= pow(2, a).x;
        return (*this);
    }
    mint& operator>>=(const mint a) {
        x /= pow(2, a).x;
        return (*this);
    }

    /* 算術演算子 */
    mint operator+(const mint a) const {
        mint res(*this);
        return res += a;
    }
    mint operator-(const mint a) const {
        mint res(*this);
        return res -= a;
    }
    mint operator*(const mint a) const {
        mint res(*this);
        return res *= a;
    }
    mint operator/(const mint a) const {
        mint res(*this);
        return res /= a;
    }
    mint operator%(const mint a) const {
        mint res(*this);
        return res %= a;
    }
    mint operator&(const mint a) const {
        mint res(*this);
        return res &= a;
    }
    mint operator|(const mint a) const {
        mint res(*this);
        return res |= a;
    }
    mint operator^(const mint a) const {
        mint res(*this);
        return res ^= a;
    }
    mint operator<<(const mint a) const {
        mint res(*this);
        return res <<= a;
    }
    mint operator>>(const mint a) const {
        mint res(*this);
        return res >>= a;
    }

    /* 条件演算子 */
    bool operator==(const mint a) const noexcept { return x == a.x; }
    bool operator!=(const mint a) const noexcept { return x != a.x; }
    bool operator<(const mint a) const noexcept { return x < a.x; }
    bool operator>(const mint a) const noexcept { return x > a.x; }
    bool operator<=(const mint a) const noexcept { return x <= a.x; }
    bool operator>=(const mint a) const noexcept { return x >= a.x; }

    /* ストリーム挿入演算子 */
    friend istream& operator>>(istream& is, mint& m) {
        is >> m.x;
        m.x = (m.x % MOD + MOD) % MOD;
        return is;
    }
    friend ostream& operator<<(ostream& os, const mint& m) {
        os << m.x;
        return os;
    }

    /* その他の関数 */
    mint modinv(mint a) { return pow(a, MOD - 2); }
    mint pow(mint x, mint n) {
        mint res = 1;
        while (n.x > 0) {
            if ((n % 2).x) res *= x;
            x *= x;
            n.x /= 2;
        }
        return res;
    }
    mint powll(mint x, ll n) {
        mint res = 1;
        while (n > 0) {
            if (n % 2) res *= x;
            x *= x;
            n /= 2;
        }
        return res;
    }
};

#define MAX_MINT_NCK 201010
mint f[MAX_MINT_NCK], rf[MAX_MINT_NCK];

bool isinit = false;

void init() {
    f[0] = 1;
    rf[0] = 1;
    for (int i = 1; i < MAX_MINT_NCK; i++) {
        f[i] = f[i - 1] * i;
        rf[i] = f[i].pow(f[i], MOD - 2);
    }
}

mint nCk(mint n, mint k) {
    if (n < k) return 0;
    if (!isinit) {
        init();
        isinit = 1;
    }
    mint nl = f[n.x];          // n!
    mint nkl = rf[n.x - k.x];  // (n-k)!
    mint kl = rf[k.x];         // k!
    mint nkk = (nkl.x * kl.x);

    return nl * nkk;
}

int main() {
    ll n, k;
    cin >> n >> k;
    mint ans = 0;
    for (ll i = 0; i <= k; i++) {
        ans += mint(1).pow(-1, i) * nCk(k, i) * mint(1).powll(k - i, n);
    }
    cout << ans << endl;
}
0