結果

問題 No.840 ほむほむほむら
ユーザー noshi91
提出日時 2019-06-14 21:51:46
言語 C++14 (gcc 13.2.0 + boost 1.83.0)
結果 AC
実行時間 203 ms / 4,000 ms
コード長 7,505 bytes
コンパイル時間 736 ms
コンパイル使用メモリ 79,380 KB
実行使用メモリ 4,384 KB
最終ジャッジ日時 2023-08-08 23:24:54
合計ジャッジ時間 5,398 ms
ジャッジサーバーID (参考情報) judge12 / judge15
このコードへのチャレンジ(β)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 33 ms 4,380 KB
testcase_01 AC 29 ms 4,376 KB
testcase_02 AC 25 ms 4,380 KB
testcase_03 AC 26 ms 4,376 KB
testcase_04 AC 29 ms 4,376 KB
testcase_05 AC 53 ms 4,376 KB
testcase_06 AC 56 ms 4,376 KB
testcase_07 AC 57 ms 4,380 KB
testcase_08 AC 46 ms 4,376 KB
testcase_09 AC 57 ms 4,376 KB
testcase_10 AC 93 ms 4,380 KB
testcase_11 AC 111 ms 4,380 KB
testcase_12 AC 88 ms 4,376 KB
testcase_13 AC 113 ms 4,380 KB
testcase_14 AC 107 ms 4,376 KB
testcase_15 AC 172 ms 4,376 KB
testcase_16 AC 171 ms 4,380 KB
testcase_17 AC 200 ms 4,376 KB
testcase_18 AC 147 ms 4,384 KB
testcase_19 AC 180 ms 4,380 KB
testcase_20 AC 10 ms 4,380 KB
testcase_21 AC 14 ms 4,376 KB
testcase_22 AC 18 ms 4,376 KB
testcase_23 AC 172 ms 4,376 KB
testcase_24 AC 203 ms 4,376 KB
testcase_25 AC 18 ms 4,376 KB
testcase_26 AC 26 ms 4,376 KB
testcase_27 AC 168 ms 4,376 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

//#define NDEBUG
#include <cstddef>
#include <cstdint>
#include <vector>

namespace n91 {

// Short names for the "fast" fixed-width integer types plus the
// signed/unsigned size types.
using i8 = std::int_fast8_t;
using i32 = std::int_fast32_t;
using i64 = std::int_fast64_t;
using u8 = std::uint_fast8_t;
using u32 = std::uint_fast32_t;
using u64 = std::uint_fast64_t;
using isize = std::ptrdiff_t;
using usize = std::size_t;

// User-defined literal: `7_z` is the value 7 typed as usize.
constexpr usize operator"" _z(unsigned long long literal) {
  return static_cast<usize>(literal);
}

// Ascending half-open index range for range-based for:
// `for (auto i : rep(a, b))` visits a, a + 1, ..., b - 1.
class rep {
  const usize lo_, hi_;

public:
  // Minimal iterator: just enough (++, *, !=) for range-based for.
  class itr {
    friend rep;
    usize cur;
    constexpr itr(const usize start) noexcept : cur(start) {}

  public:
    void operator++() noexcept { cur += 1; }
    constexpr usize operator*() const noexcept { return cur; }
    constexpr bool operator!=(const itr other) const noexcept {
      return !(cur == other.cur);
    }
  };
  constexpr rep(const usize first, const usize last) noexcept
      : lo_(first), hi_(last) {}
  constexpr itr begin() const noexcept { return itr(lo_); }
  constexpr itr end() const noexcept { return itr(hi_); }
};

// Descending range over [first, last): `for (auto i : revrep(a, b))` visits
// b - 1, b - 2, ..., a. Both bounds are shifted down by one. When first == 0
// the stop value wraps to the maximum usize, which is exactly what the
// iterator becomes after decrementing past 0, so the loop still stops.
class revrep {
  const usize stop_, start_;

public:
  // Minimal iterator whose ++ steps downwards.
  class itr {
    friend revrep;
    usize cur;
    constexpr itr(usize start) noexcept : cur(start) {}

  public:
    void operator++() noexcept { cur -= 1; }
    constexpr usize operator*() const noexcept { return cur; }
    constexpr bool operator!=(const itr other) const noexcept {
      return !(cur == other.cur);
    }
  };
  constexpr revrep(usize first, usize last) noexcept
      : stop_(first - static_cast<usize>(1)),
        start_(last - static_cast<usize>(1)) {}
  constexpr itr begin() const noexcept { return itr(start_); }
  constexpr itr end() const noexcept { return itr(stop_); }
};

template <class T> using vec_alias = std::vector<T>;

// Nested std::vector whose dimensions are given by the leading arguments and
// whose elements are all copies of the final argument:
// md_vec(2, 3, x) is a 2 x 3 vector<vector<...>> filled with x.
template <class T> auto md_vec(const usize n, const T &value) {
  std::vector<T> flat(n, value);
  return flat;
}
template <class... Args> auto md_vec(const usize n, Args... args) {
  auto inner = md_vec(args...);
  return std::vector<decltype(inner)>(n, inner);
}

// |a - b| computed without going below zero, so it is safe for unsigned T.
template <class T> constexpr T difference(const T &a, const T &b) {
  if (a < b) {
    return b - a;
  }
  return a - b;
}

} // namespace n91

#include <cstdint>

namespace n91 {

/// Euler's totient function: the number of integers in [1, x] coprime to x.
///
/// Trial division: each prime factor p of x is stripped completely and applied
/// as ret -= ret / p. Whatever is left above 1 afterwards is one more prime
/// factor. Runs in O(sqrt(x)) divisions.
///
/// @param x  value whose totient is wanted.
/// @return   phi(x), or 0 for x == 0 (phi(0) is undefined, and the formula
///           below would otherwise divide by zero).
constexpr std::uint_fast64_t totient(std::uint_fast64_t x) noexcept {
  using u64 = std::uint_fast64_t;
  if (x == static_cast<u64>(0)) {
    return static_cast<u64>(0);
  }
  u64 ret = x;
  // `i <= x / i` is `i * i <= x` without overflow. For x close to the type's
  // maximum, i * i wraps around and the old condition kept the loop running
  // far past sqrt(x).
  for (u64 i = static_cast<u64>(2); i <= x / i; ++i) {
    if (x % i == static_cast<u64>(0)) {
      ret -= ret / i;
      x /= i;
      while (x % i == static_cast<u64>(0)) {
        x /= i;
      }
    }
  }
  if (x != static_cast<u64>(1)) {
    // The remaining x is a single prime factor not yet divided out.
    ret -= ret / x;
  }
  return ret;
}

/// Integer modulo `Modulus`, where Modulus must lie in [1, 2^32).
///
/// Arithmetic keeps the stored value reduced into [0, Modulus). Callers that
/// write through the non-const value() must keep it reduced themselves.
/// Products of two reduced values stay below 2^64 because Modulus < 2^32.
///
/// Division multiplies by rhs^InverseExp. With the default
/// InverseExp = phi(Modulus) - 1, this is the modular inverse of rhs whenever
/// gcd(rhs, Modulus) == 1 (Euler's theorem). For other rhs the result is not
/// a true quotient.
template <std::uint_fast64_t Modulus,
          std::uint_fast64_t InverseExp =
              totient(Modulus) - static_cast<std::uint_fast64_t>(1)>
class modint {
  using u64 = std::uint_fast64_t;

  static_assert(Modulus != static_cast<u64>(0), "Modulus must be positive");
  static_assert(Modulus < static_cast<u64>(1) << static_cast<u64>(32),
                "Modulus must be less than 2**32");

  u64 a; // representative of the residue class

  // In-place additive inverse; 0 stays 0 so the value stays < Modulus.
  constexpr modint &negate() noexcept {
    if (a != static_cast<u64>(0)) {
      a = Modulus - a;
    }
    return *this;
  }

public:
  constexpr modint(const u64 x = static_cast<u64>(0)) noexcept
      : a(x % Modulus) {}
  constexpr u64 &value() noexcept { return a; }
  constexpr u64 value() const noexcept { return a; }
  constexpr modint operator+() const noexcept { return modint(*this); }
  constexpr modint operator-() const noexcept { return modint(*this).negate(); }
  constexpr modint operator+(const modint rhs) const noexcept {
    return modint(*this) += rhs;
  }
  constexpr modint operator-(const modint rhs) const noexcept {
    return modint(*this) -= rhs;
  }
  constexpr modint operator*(const modint rhs) const noexcept {
    return modint(*this) *= rhs;
  }
  constexpr modint operator/(const modint rhs) const noexcept {
    return modint(*this) /= rhs;
  }
  constexpr modint &operator+=(const modint rhs) noexcept {
    a += rhs.a;
    if (a >= Modulus) {
      a -= Modulus;
    }
    return *this;
  }
  constexpr modint &operator-=(const modint rhs) noexcept {
    // Add Modulus first so the unsigned subtraction cannot wrap.
    if (a < rhs.a) {
      a += Modulus;
    }
    a -= rhs.a;
    return *this;
  }
  constexpr modint &operator*=(const modint rhs) noexcept {
    a = a * rhs.a % Modulus;
    return *this;
  }
  constexpr modint &operator/=(modint rhs) noexcept {
    // *this *= rhs^InverseExp, by binary exponentiation.
    u64 exp = InverseExp;
    while (exp) {
      if (exp % static_cast<u64>(2) != static_cast<u64>(0)) {
        *this *= rhs;
      }
      rhs *= rhs;
      exp /= static_cast<u64>(2);
    }
    return *this;
  }
  constexpr bool operator==(const modint rhs) const noexcept {
    return a == rhs.a;
  }
  constexpr bool operator!=(const modint rhs) const noexcept {
    return a != rhs.a;
  }
};

} // namespace n91

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <initializer_list>

/// Fixed-size N x M matrix stored as a std::array of N rows of M elements.
/// It inherits std::array's interface, so m[i][j] is row i, column j.
template <class T, std::size_t N, std::size_t M>
class array_matrix : public std::array<std::array<T, M>, N> {
public:
  /// Zero matrix: every element is static_cast<T>(0).
  array_matrix() {
    for (auto &row : *this) {
      std::fill(row.begin(), row.end(), static_cast<T>(0));
    }
  }

  /// Row-wise construction, e.g. {{1, 2}, {3, 4}}.
  ///
  /// Expects exactly N rows of exactly M values (checked by assert). The
  /// matrix is zero-filled first, and at most N rows and M values per row are
  /// copied. With NDEBUG, a malformed list therefore no longer reads past
  /// `il`, writes past a row, or leaves elements uninitialized.
  array_matrix(std::initializer_list<std::initializer_list<T>> il)
      : array_matrix() {
    assert(il.size() == N);
    auto dst = this->begin();
    for (auto src = il.begin(); src != il.end() && dst != this->end();
         ++src, ++dst) {
      assert(src->size() == M);
      std::copy_n(src->begin(), std::min(src->size(), M), dst->begin());
    }
  }
};

// Dense matrix product: (N x M) * (M x L) -> (N x L).
// The i-j-k loop order keeps the innermost loop walking one row of `rhs`
// and one row of the result.
template <class T, std::size_t N, std::size_t M, std::size_t L>
array_matrix<T, N, L> operator*(const array_matrix<T, N, M> &lhs,
                                const array_matrix<T, M, L> &rhs) {
  array_matrix<T, N, L> product; // default constructor zero-fills
  for (std::size_t row = 0; row < N; ++row) {
    auto &out = product[row];
    for (std::size_t mid = 0; mid < M; ++mid) {
      const T &coef = lhs[row][mid];
      const auto &src = rhs[mid];
      for (std::size_t col = 0; col < L; ++col) {
        out[col] += coef * src[col];
      }
    }
  }
  return product;
}

// N x N identity matrix: the default constructor supplies zeros, and the loop
// sets a one on the main diagonal of each row.
template <class T, std::size_t N> array_matrix<T, N, N> identity() {
  array_matrix<T, N, N> unit;
  std::size_t diag = 0;
  for (auto &row : unit) {
    row[diag] = static_cast<T>(1);
    ++diag;
  }
  return unit;
}

#include <functional>
#include <utility>

namespace n91 {

/// Exponentiation by squaring under an arbitrary binary operation.
///
/// Combines `iden` with `exp` copies of `base` using `oper`, in
/// O(log exp) calls to `oper`. With the defaults (std::multiplies, iden = 1)
/// this is base^exp. `oper` is expected to be associative, and `exp` is
/// presumably non-negative (it is consumed with `% 2` and `/= 2`).
///
/// @param base  element to raise.
/// @param exp   number of times `base` is folded in.
/// @param oper  binary operation; defaults to multiplication.
/// @param iden  starting accumulator (identity of `oper`); defaults to 1.
/// @return      the accumulated value.
template <class T, class U, class Operate = std::multiplies<T>>
constexpr T power(T base, U exp, const Operate &oper = Operate(),
                  T iden = static_cast<T>(1)) {
  while (exp != static_cast<U>(0)) {
    if (exp % static_cast<U>(2) != static_cast<U>(0)) {
      iden = oper(iden, base);
    }
    exp /= static_cast<U>(2);
    // Square only while bits remain. Squaring after the last bit was wasted
    // work (a whole extra matrix product for matrix T). For signed T it could
    // also overflow (UB) even though the returned value fits.
    if (exp != static_cast<U>(0)) {
      base = oper(base, base);
    }
  }
  return iden;
}

} // namespace n91

#include <algorithm>
#include <iostream>
#include <utility>

namespace n91 {

void main_() {
  u64 n;
  usize k;
  std::cin >> n >> k;
  using mint = modint<998244353>;
  array_matrix<mint, 125_z, 125_z> trans;
  const auto get = [&k](const usize i, const usize j, const usize s) {
    return k * k * i + k * j + s;
  };
  for (const auto i : rep(0_z, k)) {
    for (const auto j : rep(0_z, k)) {
      for (const auto s : rep(0_z, k)) {
        trans[get(i, j, s)][get((i + 1_z) % k, j, s)] += static_cast<mint>(1);
        trans[get(i, j, s)][get(i, (j + i) % k, s)] += static_cast<mint>(1);
        trans[get(i, j, s)][get(i, j, (s + j) % k)] += static_cast<mint>(1);
      }
    }
  }
  array_matrix<mint, 1_z, 125_z> vec;
  vec[0_z][0_z] = static_cast<mint>(1);
  auto ret =
      vec * power(trans, n, std::multiplies<>(), identity<mint, 125_z>());
  mint ans = static_cast<mint>(0);
  for (const auto i : rep(0, k)) {
    for (const auto j : rep(0, k)) {
      ans += ret[0_z][get(i, j, 0_z)];
    }
  }
  std::cout << ans.value() << std::endl;
}

} // namespace n91

int main() {
  n91::main_();
  return 0;
}
0