結果
問題 | No.2529 Treasure Hunter |
ユーザー | |
提出日時 | 2023-11-03 23:08:27 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 | WA |
実行時間 | - |
コード長 | 5,128 bytes |
コンパイル時間 | 2,472 ms |
コンパイル使用メモリ | 214,540 KB |
最終ジャッジ日時 | 2025-02-17 18:47:50 |
ジャッジサーバーID (参考情報) | judge4 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 8 WA * 14 |
ソースコード
#include <bits/stdc++.h>
// Prefer the AtCoder Library's modint when it is installed (as on the judge);
// otherwise fall back to a minimal stand-in so the file still builds.
#if __has_include(<atcoder/modint>)
#include <atcoder/modint>
using mint = atcoder::modint998244353;
#else
/// Minimal integer modulo 998244353 (a prime). Provides only the subset of
/// atcoder::modint998244353 that this program uses.
class mint {
  public:
    static constexpr unsigned int MOD = 998244353;

    mint() = default;

    /// Implicit conversion from any integer type; negative values are
    /// normalized into [0, MOD).
    template <class I, std::enable_if_t<std::is_integral_v<I>, std::nullptr_t> = nullptr>
    mint(I x) {
        if constexpr (std::is_signed_v<I>) {
            long long r = static_cast<long long>(x) % static_cast<long long>(MOD);
            v_ = static_cast<unsigned int>(r < 0 ? r + MOD : r);
        } else {
            v_ = static_cast<unsigned int>(static_cast<unsigned long long>(x) % MOD);
        }
    }

    /// Representative in [0, MOD).
    unsigned int val() const { return v_; }

    mint &operator+=(const mint &o) {
        v_ += o.v_;
        if (v_ >= MOD) v_ -= MOD;
        return *this;
    }
    mint &operator-=(const mint &o) {
        v_ += MOD - o.v_;
        if (v_ >= MOD) v_ -= MOD;
        return *this;
    }
    mint &operator*=(const mint &o) {
        v_ = static_cast<unsigned int>(static_cast<unsigned long long>(v_) * o.v_ % MOD);
        return *this;
    }
    mint &operator/=(const mint &o) { return *this *= o.inv(); }

    /// this^e by binary exponentiation.
    mint pow(unsigned long long e) const {
        mint base = *this, acc = 1;
        for (; e; e >>= 1, base *= base)
            if (e & 1) acc *= base;
        return acc;
    }
    /// Multiplicative inverse via Fermat's little theorem (MOD is prime).
    mint inv() const { return pow(MOD - 2); }

    friend mint operator+(mint a, const mint &b) { return a += b; }
    friend mint operator-(mint a, const mint &b) { return a -= b; }
    friend mint operator*(mint a, const mint &b) { return a *= b; }
    friend mint operator/(mint a, const mint &b) { return a /= b; }

  private:
    unsigned int v_ = 0;
};
#endif
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
using namespace std;

/// Dense H x W matrix over a ring T, stored as a vector of rows.
template <typename T>
class Matrix {
  public:
    Matrix() {}
    explicit Matrix(int N) : Matrix(N, N) {}
    explicit Matrix(int H, int W) : mat(H, vector<T>(W)) {}

    int height() const { return (int)mat.size(); }
    // Guarded so a default-constructed (0 x 0) matrix never indexes mat[0].
    int width() const { return mat.empty() ? 0 : (int)mat[0].size(); }

    const std::vector<T> &operator[](int k) const { return mat[k]; }
    std::vector<T> &operator[](int k) { return mat[k]; }

    /// N x N identity matrix.
    static inline Matrix I(int N) {
        Matrix ret(N);
        for (int i = 0; i < N; i++) ret[i][i] = T(1);
        return ret;
    }

    /// Element-wise addition; shapes must match.
    Matrix &operator+=(const Matrix &other) {
        const int H = height(), W = width();
        assert(H == other.height() && W == other.width());
        for (int i = 0; i < H; i++)
            for (int j = 0; j < W; j++) mat[i][j] += other[i][j];
        return *this;
    }
    /// Adds the scalar X to every entry.
    Matrix &operator+=(T X) {
        for (auto &row : mat)
            for (auto &x : row) x += X;
        return *this;
    }
    /// Element-wise subtraction; shapes must match.
    Matrix &operator-=(const Matrix &other) {
        const int H = height(), W = width();
        assert(H == other.height() && W == other.width());
        for (int i = 0; i < H; i++)
            for (int j = 0; j < W; j++) mat[i][j] -= other[i][j];
        return *this;
    }
    /// Subtracts the scalar X from every entry.
    Matrix &operator-=(T X) {
        for (auto &row : mat)
            for (auto &x : row) x -= X;
        return *this;
    }
    /// Multiplies every entry by the scalar X.
    Matrix &operator*=(T X) {
        for (auto &row : mat)
            for (auto &x : row) x *= X;
        return *this;
    }
    /// Divides every entry by the scalar X.
    Matrix &operator/=(T X) {
        for (auto &row : mat)
            for (auto &x : row) x /= X;
        return *this;
    }

    Matrix operator+(const Matrix &other) const { return (Matrix(*this) += other); }
    Matrix operator+(T X) const { return (Matrix(*this) += X); }
    Matrix operator-(const Matrix &other) const { return (Matrix(*this) -= other); }
    Matrix operator-(T X) const { return (Matrix(*this) -= X); }
    Matrix operator*(T X) const { return (Matrix(*this) *= X); }
    Matrix operator/(T X) const { return (Matrix(*this) /= X); }

    /// In-place product: *this = (*this) * other. Returns *this.
    /// Aliasing-safe (e.g. A.mat_mul(A)) because the product is accumulated
    /// in a separate buffer before replacing the contents.
    Matrix &mat_mul(const Matrix &other) {
        const int h0 = height(), w0 = width();
        const int h1 = other.height(), w1 = other.width();
        assert(w0 == h1);
        vector<vector<T>> ret(h0, vector<T>(w1, T(0)));
        for (int i = 0; i < h0; i++)
            for (int j = 0; j < w1; j++)
                for (int k = 0; k < w0; k++) ret[i][j] += mat[i][k] * other[k][j];
        mat.swap(ret);
        return *this;
    }

    /// k-th power by binary exponentiation; requires a square matrix and
    /// k >= 0 (a negative k would never terminate the shift loop).
    Matrix pow(long long k) const {
        assert(height() == width());
        assert(k >= 0);
        Matrix base = *this;
        Matrix result = Matrix::I(height());
        for (; k > 0; k >>= 1) {
            if (k & 1) result.mat_mul(base);
            base.mat_mul(base);
        }
        return result;
    }

    /// Sum of all entries. (Previously declared to return Matrix, which
    /// could not compile if ever called.)
    T sum() const {
        T ret = 0;
        for (const auto &row : mat)
            for (const auto &x : row) ret += x;
        return ret;
    }

  private:
    std::vector<std::vector<T>> mat;
};

/// Reads one test case "M N" from stdin and prints the count modulo 998244353.
///
/// DP over floors: the state of a floor is how many of its rooms are chosen
/// (0, 1 or 2). trans[i][j] = number of ways to choose i rooms on a floor
/// given a fixed choice of j rooms on the floor directly below.
/// NOTE(review): the model below is inferred from the counting formulas, so
/// confirm it against the problem statement. The M rooms of a floor form a
/// cycle; no two cyclically adjacent rooms are chosen; a room may not be
/// chosen when the room at the same position on the floor below is chosen.
void solve() {
    // 64-bit input, and all arithmetic in mint. The former int expressions
    // such as M * (M / 2 - 2) overflowed (UB) for M larger than about 6.5e4.
    long long N, M;
    cin >> M >> N;

    const mint rooms = M;

    // Non-adjacent room pairs on an M-cycle that avoid the positions blocked
    // from below. For M <= 3 every pair is adjacent, so there are none. The
    // old code printed 0 for M <= 3, but the empty selection (and single
    // rooms) are always valid, so the general DP is used with state 2 empty.
    mint pairs_free = 0, pairs_avoid1 = 0, pairs_avoid2 = 0;
    if (M >= 4) {
        // M(M-3)/2: each room pairs with the M-3 rooms that are neither itself
        // nor one of its two neighbours.
        pairs_free = rooms * (rooms - 3) / 2;
        // Remove the M-3 pairs through the single blocked room.
        pairs_avoid1 = (rooms - 2) * (rooms - 3) / 2;
        // Inclusion-exclusion over the two blocked rooms. They are themselves
        // a non-adjacent pair, so that pair was subtracted twice.
        pairs_avoid2 = pairs_free - 2 * (rooms - 3) + 1;
    }

    Matrix<mint> trans(3);
    for (int j = 0; j < 3; j++) {
        trans[0][j] = 1;          // choose nothing on this floor
        trans[1][j] = rooms - j;  // any one room not blocked from below
    }
    trans[2][0] = pairs_free;
    trans[2][1] = pairs_avoid1;
    trans[2][2] = pairs_avoid2;

    // The bottom floor has nothing below it, so it behaves like a floor above
    // an empty one: its counts are column 0 of trans. The total over N floors
    // is therefore the column-0 sum of trans^N, which equals
    // sum(trans^(N-1) * {1, M, M(M-3)/2}).
    const Matrix<mint> power = trans.pow(N);
    mint ans = 0;
    for (int i = 0; i < 3; i++) ans += power[i][0];
    cout << ans.val() << '\n';
}

int main() {
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    int T;
    cin >> T;
    while (T--) solve();
    return 0;
}