結果

問題 No.840 ほむほむほむら
ユーザー yutake2000
提出日時 2023-06-11 20:44:23
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 272 ms / 4,000 ms
コード長 11,774 bytes
コンパイル時間 3,059 ms
コンパイル使用メモリ 216,976 KB
実行使用メモリ 4,388 KB
最終ジャッジ日時 2023-08-31 02:05:09
合計ジャッジ時間 7,255 ms
ジャッジサーバーID
(参考情報)
judge15 / judge12
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
4,384 KB
testcase_01 AC 3 ms
4,384 KB
testcase_02 AC 6 ms
4,384 KB
testcase_03 AC 39 ms
4,384 KB
testcase_04 AC 2 ms
4,384 KB
testcase_05 AC 1 ms
4,380 KB
testcase_06 AC 2 ms
4,380 KB
testcase_07 AC 12 ms
4,384 KB
testcase_08 AC 68 ms
4,380 KB
testcase_09 AC 2 ms
4,384 KB
testcase_10 AC 2 ms
4,384 KB
testcase_11 AC 3 ms
4,380 KB
testcase_12 AC 18 ms
4,388 KB
testcase_13 AC 176 ms
4,380 KB
testcase_14 AC 22 ms
4,380 KB
testcase_15 AC 2 ms
4,384 KB
testcase_16 AC 4 ms
4,380 KB
testcase_17 AC 39 ms
4,380 KB
testcase_18 AC 223 ms
4,384 KB
testcase_19 AC 272 ms
4,384 KB
testcase_20 AC 1 ms
4,384 KB
testcase_21 AC 2 ms
4,384 KB
testcase_22 AC 5 ms
4,380 KB
testcase_23 AC 259 ms
4,380 KB
testcase_24 AC 4 ms
4,384 KB
testcase_25 AC 1 ms
4,384 KB
testcase_26 AC 6 ms
4,380 KB
testcase_27 AC 259 ms
4,380 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#ifndef ATCODER_MODINT_HPP
#define ATCODER_MODINT_HPP 1

#include <atcoder/internal_math>
#include <atcoder/internal_type_traits>
#include <cassert>
#include <numeric>
#include <type_traits>

#ifdef _MSC_VER
#include <intrin.h>
#endif

namespace atcoder {

namespace internal {

// Empty tag type. Every modint flavour derives from it, so generic code can
// recognise a modint with std::is_base_of.
struct modint_base {};

// Tag refinement for modints whose modulus is fixed at compile time.
struct static_modint_base : public modint_base {};

// is_modint<T>::value is true when T is modint_base or derives from it.
template <typename T>
using is_modint = std::is_base_of<modint_base, T>;

// SFINAE helper: names `void` for modint types, is ill-formed otherwise.
template <typename T>
using is_modint_t = typename std::enable_if<is_modint<T>::value>::type;

}  // namespace internal

/// Integer modulo a compile-time constant m (1 <= m).
///
/// The value is stored as an unsigned int kept in [0, m); every operator below
/// re-establishes that range. Division goes through inv(), which asserts that
/// the inverse exists.
template <int m, std::enable_if_t<(1 <= m)>* = nullptr>
struct static_modint : internal::static_modint_base {
    using mint = static_modint;

  public:
    // The modulus m.
    static constexpr int mod() { return m; }
    // Wraps v WITHOUT reducing it; the caller must already have 0 <= v < m.
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }

    static_modint() : _v(0) {}
    // Signed integers: reduce through long long, then shift a negative
    // remainder (C++ % keeps the dividend's sign) up into [0, m).
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    static_modint(T v) {
        long long x = (long long)(v % (long long)(umod()));
        if (x < 0) x += umod();
        _v = (unsigned int)(x);
    }
    // Unsigned integers: the remainder is already non-negative.
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    static_modint(T v) {
        _v = (unsigned int)(v % umod());
    }
    static_modint(bool v) { _v = ((unsigned int)(v) % umod()); }

    // Stored representative in [0, m).
    unsigned int val() const { return _v; }

    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        // Move 0 up to m first so the decrement lands on m - 1.
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }

    mint& operator+=(const mint& rhs) {
        // Both operands are < m, so one conditional subtraction suffices.
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        // If _v < rhs._v the unsigned subtraction wraps to a value >= umod();
        // adding umod() then wraps back into [0, m).
        _v -= rhs._v;
        if (_v >= umod()) _v += umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        // Multiply in 64 bits: two values below m (< 2^31) can overflow 32 bits.
        unsigned long long z = _v;
        z *= rhs._v;
        _v = (unsigned int)(z % umod());
        return *this;
    }
    // Division multiplies by rhs.inv(), so it asserts exactly when inv() does.
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }

    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }

    // Binary exponentiation: *this raised to n. n must be non-negative (asserted).
    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    // Multiplicative inverse.
    // Prime m: v^(m-2) (Fermat's little theorem); asserts v != 0.
    // Otherwise: internal::inv_gcd (defined outside this file). The code
    // asserts its .first == 1 and returns .second — presumably (gcd, inverse).
    mint inv() const {
        if (prime) {
            assert(_v);
            return pow(umod() - 2);
        } else {
            auto eg = internal::inv_gcd(_v, m);
            assert(eg.first == 1);
            return eg.second;
        }
    }

    // Binary operators copy the left operand and reuse the compound forms.
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }

  private:
    // Representative in [0, m).
    unsigned int _v;
    static constexpr unsigned int umod() { return m; }
    // Compile-time primality of m, from internal::is_prime (defined outside this file).
    static constexpr bool prime = internal::is_prime<m>;
};

/// Integer modulo a modulus chosen at run time.
///
/// The modulus lives in a static internal::barrett object `bt`, so it is shared
/// by every dynamic_modint<id> with the same `id`. Different ids can hold
/// different moduli at the same time. Values are stored as unsigned ints in
/// [0, mod()).
template <int id> struct dynamic_modint : internal::modint_base {
    using mint = dynamic_modint;

  public:
    // Current modulus shared by all values of this id.
    static int mod() { return (int)(bt.umod()); }
    // Replaces the shared modulus (asserts 1 <= m).
    // NOTE(review): existing values are not re-reduced, so values created under
    // an earlier modulus may be >= the new one.
    static void set_mod(int m) {
        assert(1 <= m);
        bt = internal::barrett(m);
    }
    // Wraps v WITHOUT reducing it; the caller must already have 0 <= v < mod().
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }

    dynamic_modint() : _v(0) {}
    // Signed integers: reduce through long long, then shift a negative
    // remainder up into [0, mod()).
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        long long x = (long long)(v % (long long)(mod()));
        if (x < 0) x += mod();
        _v = (unsigned int)(x);
    }
    // Unsigned integers: the remainder is already non-negative.
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        _v = (unsigned int)(v % mod());
    }
    dynamic_modint(bool v) { _v = ((unsigned int)(v) % mod()); }

    // Stored representative in [0, mod()).
    unsigned int val() const { return _v; }

    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        // Move 0 up to the modulus first so the decrement lands on mod() - 1.
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }

    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        // Add the additive inverse (mod() - rhs._v), then reduce once; the sum
        // stays below 2 * mod(), so one conditional subtraction suffices.
        _v += mod() - rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        // Modular product delegated to the shared reducer's mul()
        // (internal::barrett, defined outside this file).
        _v = bt.mul(_v, rhs._v);
        return *this;
    }
    // Division multiplies by rhs.inv(), so it asserts exactly when inv() does.
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }

    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }

    // Binary exponentiation: *this raised to n. n must be non-negative (asserted).
    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    // Multiplicative inverse via internal::inv_gcd. The runtime modulus is not
    // assumed prime. The code asserts .first == 1 and returns .second —
    // presumably (gcd, inverse).
    mint inv() const {
        auto eg = internal::inv_gcd(_v, mod());
        assert(eg.first == 1);
        return eg.second;
    }

    // Binary operators copy the left operand and reuse the compound forms.
    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }

  private:
    // Representative in [0, mod()).
    unsigned int _v;
    // Shared per-id modulus holder and reducer.
    static internal::barrett bt;
    static unsigned int umod() { return bt.umod(); }
};
// Each id's reducer is initialised from 998244353 — presumably the default
// modulus until set_mod() is called.
template <int id> internal::barrett dynamic_modint<id>::bt = 998244353;

// Ready-made compile-time modints for the two most common contest primes.
typedef static_modint<998244353> modint998244353;
typedef static_modint<1000000007> modint1000000007;
// Runtime-modulus modint; change the modulus with modint::set_mod(m).
typedef dynamic_modint<-1> modint;

namespace internal {

template <class T>
using is_static_modint = std::is_base_of<internal::static_modint_base, T>;

template <class T>
using is_static_modint_t = std::enable_if_t<is_static_modint<T>::value>;

template <class> struct is_dynamic_modint : public std::false_type {};
template <int id>
struct is_dynamic_modint<dynamic_modint<id>> : public std::true_type {};

template <class T>
using is_dynamic_modint_t = std::enable_if_t<is_dynamic_modint<T>::value>;

}  // namespace internal

}  // namespace atcoder

#endif  // ATCODER_MODINT_HPP

using namespace atcoder;
using namespace std;
// Shorthand numeric and pair types used throughout the solution.
using ll = long long;
using ld = long double;
using Pll = pair<ll, ll>;
using Pdd = pair<ld, ld>;

// Heap adaptors: MaxHeap pops the largest element first, MinHeap the smallest.
template <typename T>
using MaxHeap = priority_queue<T>;
template <typename T>
using MinHeap = priority_queue<T, vector<T>, greater<T>>;
// Loop helpers with an int counter. The bound arguments are NOT parenthesised:
// REP(i, a + b) is fine, but REP(i, c ? a : b) expands to (i < c) ? a : b.
// REP: i = 0 .. n-1.  REPR: i = n down to 0 inclusive.
// FOR: i = m .. n-1.  FORR: i = m down to n inclusive.
#define REP(i, n) for(int i = 0; i < n; i++)
#define REPR(i, n) for(int i = n; i >= 0; i--)
#define FOR(i, m, n) for(int i = m; i < n; i++)
#define FORR(i, m, n) for(int i = m; i >= n; i--)
// "Infinity" sentinel for ll values: 10^17.
#define INF (ll)1e17
// Whole-container iterator pair for algorithm calls.
#define ALL(v) v.begin(), v.end()
// Container size as int (avoids signed/unsigned comparisons with int counters).
#define SZ(x) ((int)(x).size())
// Rename these identifiers to unique tokens, presumably so user variables named
// y0/y1/j0/j1 do not clash with the Bessel functions glibc's math.h declares,
// and next/prev do not clash with std::next/std::prev under `using namespace std`.
#define y0 y3487465
#define y1 y8687969
#define j0 j1347829
#define j1 j234892
#define next asdnext
#define prev asdprev
// 2^n as long long.
#define bit(n) (1LL<<(n))
// Erases ADJACENT duplicates only (sort v first for a full dedup). The
// expansion already ends in ';'.
#define UNIQUE(v) v.erase( unique(v.begin(), v.end()), v.end() );
#define cauto const auto&
#define pb push_back
// Note: this replaces every `mp` token, so no variable may be named mp.
#define mp make_pair
// Debug output to cerr, active only when debug_mode is true. Each expands to an
// unbraced if-statement, so brace the call if an `else` follows. The argument is
// not parenthesised, which lets debug(a SP b) stream several values.
#define debug(v) if (debug_mode) cerr << v << "\t";
#define debugln(v) if (debug_mode) cerr << v << "\n";
// Prints a pair-like value as "(first, second)".
#define debugP(v) if (debug_mode) cerr << "("  << v.first << ", " << v.second << ")\t";
// Prints "expr = value".
#define dump(x) if (debug_mode) cerr << #x << " = " << (x) << "\t";
// Stream separators for chaining: a SP b  ->  a << " " << b.
#define SP << " " <<
#define TB << "\t" <<

// Global switch read by the debug/dump macros and the show() helpers.
// It is on only when the build defines _LOCAL.
bool debug_mode =
#ifdef _LOCAL
    true
#else
    false
#endif
    ;

void show(const vector<ll>& arr, bool show_index = false, ll w = 4) {
    if (!debug_mode) return;

    ll max_value = 0;
    REP(i, SZ(arr)) {
        if (abs(INF - arr[i]) >= 1e5) max_value = max(max_value, arr[i]);
    }

    w = max(w, SZ(to_string(max_value))+1LL);

    if (show_index) {
        REP(i, arr.size()) {
            cout << right << setw(w) << i;
        }
        cout << endl;
    }

    REP(i, arr.size()){
        if (abs(INF - arr[i]) < 1e5) {
            cout << right << setw(w) << (arr[i] == INF ? "INF" : "inf");
        } else {
            cout << right << setw(w) << arr[i];
        }
    }
    cout << endl;
}

void show(const vector<vector<ll>>& arr, ll w = 4) {
    if (!debug_mode) return;

    int M = arr.size(), N = arr[0].size();
    ll max_value = 0;
    REP(i, M) REP(j, N) {
            if (abs(INF - arr[i][j]) >= 1e5) max_value = max(max_value, arr[i][j]);
        }

    w = max(w, SZ(to_string(max_value))+1LL);

    cout << right << setw(w) << "#";
    REP(i, SZ(arr[0]))  {
        cout << right << setw(w) << i;
    }
    cout << endl;
    REP(i, SZ(arr)) {
        cout << right << setw(w) << i;
        REP(j, SZ(arr[0])) {
            if (abs(INF - arr[i][j]) < 1e5) {
                cout << right << setw(w) << (arr[i][j] == INF ? "INF" : "inf");
            } else {
                cout << right << setw(w) << arr[i][j];
            }
        }
        cout << endl;
    }
    cout << endl;
}

void show(const vector<vector<vector<ll>>>& arr, ll w = 4) {
    if (!debug_mode) return;

    REP(i, arr.size()) {
        cout << "i: " << to_string(i) << endl;
        show(arr[i], w);
    }
    cout << endl;
}

/// Returns an i x j x k array of long long with every entry zero.
inline vector<vector<vector<ll>>> make_vector(ll i, ll j, ll k) {
    const vector<ll> zero_row(k, 0);
    const vector<vector<ll>> zero_plane(j, zero_row);
    return vector<vector<vector<ll>>>(i, zero_plane);
}
/// Returns an i x j array of long long with every entry zero.
inline vector<vector<ll>> make_vector(ll i, ll j) {
    const vector<ll> zero_row(j, 0);
    return vector<vector<ll>>(i, zero_row);
}

// Modular integer type used by main(): arithmetic modulo 998244353.
// Swap in modint1000000007, or modint (runtime modulus via modint::set_mod),
// if a different modulus is needed.
using mint = modint998244353;

int main()
{

    ll N, K;
    cin >> N >> K;

    ll K3 = K * K * K;

    vector<vector<mint>> coeff = vector<vector<mint>>(K3, vector<mint>(K3, 0));

    auto mul = [&](vector<vector<mint>> A, vector<vector<mint>> B) {
        vector<vector<mint>> C = vector<vector<mint>>(K3, vector<mint>(K3, 0));
        REP(i, K3) REP(j, K3) REP(k, K3) {
            C[i][j] += A[i][k] * B[k][j];
        }
        return C;
    };

    auto pow = [&](vector<vector<mint>> A, ll n) {
        vector<vector<mint>> B = vector<vector<mint>>(K3, vector<mint>(K3, 0));
        REP(i, K3) B[i][i] = 1;
        while (n > 0) {
            if (n & 1) B = mul(B, A);
            A = mul(A, A);
            n >>= 1;
        }
        return B;
    };

    auto add = [&](vector<vector<mint>> &A, ll idx2, ll i, ll j, ll k, mint n) {
        i %= K; j %= K; k %= K;
        ll idx = i * K * K + j * K + k;
        A[idx][idx2] += n;
    };

    REP(i, K) REP(j, K) REP(k, K) {
        ll idx = i * K * K + j * K + k;
        add(coeff, idx, i+1, j, k, 1);
        add(coeff, idx, i, j+i, k, 1);
        add(coeff, idx, i, j, k+j, 1);
    }

    coeff = pow(coeff, N);
    
    mint ans = 0;
    REP(i, K) REP(j, K) {
        ll idx = i * K * K + j * K;
        ans += coeff[idx][0];
    }

    cout << ans.val() << endl;

    return 0;
}
0