結果

問題 No.2996 Floor Sum
ユーザー 2qbingxuan
提出日時 2024-12-22 03:01:40
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 1,610 ms / 5,000 ms
コード長 6,780 bytes
コンパイル時間 3,600 ms
コンパイル使用メモリ 257,768 KB
実行使用メモリ 36,992 KB
最終ジャッジ日時 2024-12-22 03:01:53
合計ジャッジ時間 12,275 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 12
権限があれば一括ダウンロードができます

ソースコード

diff #
プレゼンテーションモードにする

#include <bits/stdc++.h>
using namespace std;
#define all(x) begin(x), end(x)
#ifdef local
#define safe cerr << __LINE__ << " line " << __LINE__ << " safe\n"
#define debug(a...) debug_(#a, a)
#define orange(a...) orange_(#a, a)
template <typename ...T>
void debug_(const char *s, T ...a) {
cerr << "\e[1;32m(" << s << ") = (";
int cnt = sizeof...(T);
(..., (cerr << a << (--cnt ? ", " : ")\e[0m\n")));
}
template <typename I>
void orange_(const char *s, I L, I R) {
cerr << "\e[1;32m[ " << s << " ] = [ ";
for (int f = 0; L != R; ++L)
cerr << (f++ ? ", " : "") << *L;
cerr << " ]\e[0m\n";
}
#else
#define safe ((void)0)
#define debug(...) safe
#define orange(...) safe
#endif
using lld = int64_t;
using llu = uint64_t;
using llf = long double;
using u128 = __uint128_t;
lld fdiv(lld a, lld b)
{ return a / b - (a % b && (a < 0) ^ (b < 0)); }
lld cdiv(lld a, lld b)
{ return a / b + (a % b && (a < 0) ^ (b > 0)); }
/* template <typename T>
T brute(llu a, llu b, llu c, llu n, T U, T R) {
T res;
for (llu i = 1, l = 0; i <= n; i++, res = res * R)
for (llu r = (a*i+b)/c; l < r; ++l) res = res * U;
return res;
} */
template <typename T>
T euclid(llu a, llu b, llu c, llu n, T U, T R) {
if (!n) return T{};
if (b >= c)
return mpow(U, b / c) * euclid(a, b % c, c, n, U, R);
if (a >= c)
return euclid(a % c, b, c, n, U, mpow(U, a / c) * R);
llu m = (u128(a) * n + b) / c;
if (!m) return mpow(R, n);
return mpow(R, (c - b - 1) / a) * U
* euclid(c, (c - b - 1) % a, a, m - 1, R, U)
* mpow(R, n - (u128(c) * m - b - 1) / a);
}
// time complexity is O(log max(a, b, c))
// UUUU R UUUUU R ... UUU R N R R
// k R (ak+b)/c U
template <typename T, T MOD> class Modular {
public:
constexpr Modular() : v() {}
template <typename U> Modular(const U &u) { v = static_cast<T>(0 <= u && u < MOD ? u : (u%MOD+MOD)%MOD); }
template <typename U> explicit operator U() const { return U(v); }
T operator()() const { return v; }
#define REFOP(type, expr...) Modular &operator type (const Modular &rhs) { return expr, *this; }
REFOP(+=, v += rhs.v - MOD, v += MOD & (v >> width)) ; REFOP(-=, v -= rhs.v, v += MOD & (v >> width))
// fits for MOD^2 <= 9e18
REFOP(*=, v = static_cast<T>(1LL * v * rhs.v % MOD)) ; REFOP(/=, *this *= inverse(rhs.v))
#define VALOP(op) friend Modular operator op (Modular a, const Modular &b) { return a op##= b; }
VALOP(+) ; VALOP(-) ; VALOP(*) ; VALOP(/)
Modular operator-() const { return 0 - *this; }
friend bool operator == (const Modular &lhs, const Modular &rhs) { return lhs.v == rhs.v; }
friend bool operator != (const Modular &lhs, const Modular &rhs) { return lhs.v != rhs.v; }
friend std::istream & operator>>(std::istream &I, Modular &m) { T x; I >> x, m = x; return I; }
friend std::ostream & operator<<(std::ostream &O, const Modular &m) { return O << m.v; }
Modular inv() const { return inverse(v); }
Modular qpow(lld p) const {
Modular r = 1, e = *this;
while (p) {
if (p & 1) r *= e;
e *= e;
p >>= 1;
}
return r;
}
private:
constexpr static int width = sizeof(T) * 8 - 1;
T v;
static T inverse(T a) {
// copy from tourist's template
T u = 0, v = 1, m = MOD;
while (a != 0) {
T t = m / a;
m -= t * a; std::swap(a, m);
u -= t * v; std::swap(u, v);
}
assert(m == 1);
return u;
}
};
constexpr int mod = 998244353;
using Mint = Modular<int, mod>;
template <int K>
struct Mat : array<array<Mint, K>, K> {
friend Mat operator*(const Mat &a, const Mat &b) {
Mat c(0);
for (int i = 0; i < K; i++)
for (int j = i; j < K; j++)
for (int k = i; k <= j; k++)
c[i][j] += a[i][k] * b[k][j];
return c;
}
constexpr Mat(int diag = 1) {
for (int i = 0; i < K; i++)
for (int j = 0; j < K; j++)
(*this)[i][j] = diag * (i == j);
}
};
template <typename T>
T mpow(T e, llu n) {
T r;
while (n) {
if (n & 1) r = r * e;
e = e * e;
n >>= 1;
}
return r;
}
constexpr int K = 15;
Mint choose[K][K] = {};
template <int SZ> void solve() {
int p, q;
lld N, M, A, B;
cin >> p >> q >> N >> M >> A >> B;
Mat<SZ> U(0), R(0);
// (p + 1) (q + 1) + 1
// U: x += 1
// R: i += 1
// 1, x
// i, i x
const int sum_index = (p + 1) * (q + 1);
auto enc = [&](int x, int y) {
assert(x <= q && x >= 0 && y <= p && y >= 0);
return x * (p + 1) + y;
};
for (int i = 0; i <= q; i++)
for (int j = 0; j <= i; j++)
for (int z = 0; z <= p; z++)
U[enc(j, z)][enc(i, z)] = choose[i][j];
U[sum_index][sum_index] = 1;
for (int i = 0; i <= p; i++)
for (int j = 0; j <= i; j++)
for (int z = 0; z <= q; z++)
R[enc(z, j)][enc(z, i)] = choose[i][j];
for (int j = 0; j <= p; j++)
R[enc(q, j)][sum_index] = choose[p][j];
R[sum_index][sum_index] = 1;
Mint ans = 0;
if (p == 0) {
ans += Mint(fdiv(B, M)).qpow(q);
}
array<Mint, SZ> init_vector = {};
init_vector[enc(0, 0)] = 1;
bool neg = false;
// fdiv(A * i + B, M) == -cdiv(-A * i - B, M)
// == -fdiv(-A * i - B + M - 1, M)
if (A < 0) {
A = -A;
B = -B + M - 1;
for (int i = 0; i <= q; i++)
for (int j = 0; j < i; j++)
for (int z = 0; z <= p; z++)
if ((j ^ i) & 1)
U[enc(j, z)][enc(i, z)] *= -1;
neg = true;
}
// i: 0 1 2 3
// (9 - 2i) / 6: 1 1 0 0 0 -1
// fdiv(A * i + B, M) = fdiv(A * i + r, M) + fdiv(B, M)
if (B < 0) {
lld quo = fdiv(B, M);
// if neg, then x starts from -quo and keeps decreasing
// if not neg, then x starts from quo and keeps increasing
debug(A, B, quo);
for (int i = 1; i <= q; i++) {
init_vector[enc(i, 0)] = Mint(neg ? -quo : quo).qpow(i);
debug(neg, quo, i);
}
B -= quo * M;
}
debug(A, B, M, N);
for (int i = 0; i < 5; i++)
orange(U[i].begin(), U[i].begin() + 5);
for (int i = 0; i < 5; i++)
orange(R[i].begin(), R[i].begin() + 5);
auto res = euclid(A, B, M, N, U, R);
// Mint cur = 0;
// for (int i = 0; i < SZ; i++) {
// cur += init_vector[i] * res[i][enc(1, 1)];
// }
// debug(cur);
for (int i = 0; i < SZ; i++) {
ans += init_vector[i] * res[i][sum_index];
}
cout << ans << '\n';
// cout << -ans << '\n';
}
signed main() {
cin.tie(nullptr)->sync_with_stdio(false);
for (int i = 0; i < K; i++) {
choose[i][0] = 1;
for (int j = 1; j <= i; j++)
choose[i][j] = choose[i - 1][j] + choose[i - 1][j - 1];
}
int T;
cin >> T;
if (T > 5)
while (T--) solve<10>();
else
while (T--) solve<122>();
}
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
0