結果

問題 No.2996 Floor Sum
ユーザー KudeKude
提出日時 2024-12-21 20:52:17
言語 C++23
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 758 ms / 5,000 ms
コード長 7,412 bytes
コンパイル時間 4,667 ms
コンパイル使用メモリ 286,768 KB
実行使用メモリ 5,248 KB
最終ジャッジ日時 2024-12-21 20:52:26
合計ジャッジ時間 6,063 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 8 ms
5,248 KB
testcase_02 AC 2 ms
5,248 KB
testcase_03 AC 758 ms
5,248 KB
testcase_04 AC 65 ms
5,248 KB
testcase_05 AC 9 ms
5,248 KB
testcase_06 AC 45 ms
5,248 KB
testcase_07 AC 2 ms
5,248 KB
testcase_08 AC 3 ms
5,248 KB
testcase_09 AC 3 ms
5,248 KB
testcase_10 AC 2 ms
5,248 KB
testcase_11 AC 2 ms
5,248 KB
testcase_12 AC 2 ms
5,248 KB
testcase_13 AC 757 ms
5,248 KB
権限があれば一括ダウンロードができます
コンパイルメッセージ
main.cpp:166:5: warning: 'int {anonymous}::RI(int, int)' defined but not used [-Wunused-function]
  166 | int RI(int L, int R) { assert(L < R); return dist_type(L, R - 1)(gen); }
      |     ^~

ソースコード

diff #

#include<bits/stdc++.h>
namespace {
#pragma GCC diagnostic ignored "-Wunused-function"
#include<atcoder/all>
#pragma GCC diagnostic warning "-Wunused-function"
using namespace std;
using namespace atcoder;
#define rep(i,n) for(int i = 0; i < (int)(n); i++)
#define rrep(i,n) for(int i = (int)(n) - 1; i >= 0; i--)
#define all(x) begin(x), end(x)
#define rall(x) rbegin(x), rend(x)
template<class T> bool chmax(T& a, const T& b) { if (a < b) { a = b; return true; } else return false; }
template<class T> bool chmin(T& a, const T& b) { if (b < a) { a = b; return true; } else return false; }
using ll = long long;
using P = pair<int,int>;
using VI = vector<int>;
using VVI = vector<VI>;
using VL = vector<ll>;
using VVL = vector<VL>;
using mint = modint998244353;


vector<mint> power_sums(auto n, int k) {
  vector<mint> fact(k+1), ifact(k+1);
  fact[0] = 1;
  rep(i, k) fact[i+1] = fact[i] * (i+1);
  ifact[k] = fact[k].inv();
  rrep(i, k) ifact[i] = ifact[i+1] * (i+1);
  vector<mint> f(k);
  mint nm = n, n2i = 1;
  // (e^nx-1)/x
  rep(i, k) n2i *= nm, f[i] = ifact[i+1] * n2i;
  // (e^x-1)/x
  rep(i, k) {
    for (int j = 1; i + j < k; j++) {
      f[i+j] -= f[i] * ifact[j+1];
    }
    f[i] *= fact[i];
  }
  return f;
}

auto solve(int p, int q, int n, int m, int a, int b) {
  struct Comb {
    mint d[31][31];
    Comb() {
      d[0][0] = 1;
      rep(i, 30) rep(j, i + 1) d[i+1][j] += d[i][j], d[i+1][j+1] += d[i][j];
    }
    mint operator()(int n, int k) { return d[n][k]; }
  };
  static Comb C;
  assert (n >= 0);
  if (n == 0) return mint();
  n--;
  mint ans = mint(0).pow(p) * mint(b >= 0 ? b / m : (b - (m - 1)) / m).pow(q);
  // (0, n]
  auto to_floor = [&](vector<vector<mint>>& d, int n) {
    auto sms = power_sums(n + 1, p + q + 1);
    sms[0]--;  // 0^0
    for (int i = 0; i <= p + q; i++) {
      // x^p (y^i-(y-1)^i)
      int qmx = p + q - i;
      d[i].resize(qmx + 1 + 1);
      for (int j = qmx + 1; j > 0; j--) {
        d[i][j] = 0;
        for (int dj = 1; dj <= j; dj++) {
          mint v = C(j, dj) * d[i][j-dj];
          d[i][j] += dj % 2 ? v : -v;
        }
      }
      d[i][0] = sms[i];
    }
    return d;
  };
  auto to_ij = [&](vector<vector<mint>>& d) {
    for (int i = 0; i <= p + q; i++) {
      // if (i == 0) {for (auto x : d[i]) cout << x.val() << ' '; cout << endl;}
      int qmx = p + q - i;
      for (int j = 0; j <= qmx; j++) {
        d[i][j] = d[i][j+1] / C(j + 1, j);
        for (int dj = 2; j + dj <= qmx + 1; dj++) {
          mint v = C(j + dj, dj) * d[i][j];
          d[i][j+dj] += dj % 2 == 0 ? v : -v;
        }
      }
      d[i].resize(qmx + 1);
    }
  };
  auto extract_line = [&](int a, ll b, int m, int n) {
    // a/m = a1/m + a2
    // a=a1+a2m
    int a1 = a % m, a2 = a / m;
    if (a1 > 0) a1 -= m, a2++;
    ll fn = (ll)a1 * n + b;
    // fn/m = fn1/m + b2
    // fn = fn1 + b2m
    ll fn1 = fn % m, b2 = fn / m;
    if (fn1 < 0) fn1 += m, b2--;
    ll b1 = b - b2 * m;
    // (a1x+b1)/m + (a2x+b2)
    assert(a1 + a2 * m == a && b1 + b2 * m == b && -m < a1 && a1 <= 0 && 0 <= (ll)a1*n+b1 && (ll)a1*n+b1 < m);
    return tuple(a1, a2, b1, b2);
  };
  auto rec = [&](auto&& self, int a, int b, ll m) -> vector<vector<mint>> {
    // <= p+q
    if (a < b) {
      auto res = self(self, b, a, m);
      rep(i, p + q + 1) rep(j, p + q + 1 - i) if (i < j) swap(res[i][j], res[j][i]);
      return res;
    }
    assert(a > 0);
    // ax+by=m  y=(-ax+m)/b
    int xmx = m / a;
    if (xmx == 0) {
      vector<vector<mint>> res(p + q + 1);
      rep(i, p + q + 1) res[i].resize(p + q + 1 - i);
      return res;
    }
    assert(b > 0);
    auto [a1, a2, b1, b2] = extract_line(-a, m, b, xmx);
    // cout << "X"<<a << ' ' << b << ' ' << m << ' ' << a1 << ' ' << b1 << ' ' << a2 << ' ' << b2 << endl;
    auto d = self(self, -a1, b, b1);
    // cout << "B" << d[0][0].val() << ' ' << d[0][1].val()<<' ' << d[0][2].val()<<' ' << xmx << ' ' << a << ' ' << b << ' ' << m << endl;
    // cout << "B" << d[0][1].val() << ' ' << xmx << ' ' << a << ' ' << b << ' ' << m << endl;
    to_floor(d, xmx);
    // cout << "B" << d[0][0].val() << ' ' << d[0][1].val()<<' ' << d[0][2].val()<<' ' << xmx << ' ' << a << ' ' << b << ' ' << m << endl;
    // cout << "B" << d[1][0].val() << ' ' << d[2][0].val()<<' ' <<' ' << xmx << ' ' << a << ' ' << b << ' ' << m << endl;
    rep(i, p + q + 1) rrep(j, p + q + 1 - i) {
      mint s;
      rep(j1, j + 1) rep(j2, j - j1 + 1) {
        int j3 = j - j1 - j2;
        s += d[i+j2][j1] * C(j, j1) * C(j - j1, j2) * mint(a2).pow(j2) * mint(b2).pow(j3);
      }
      d[i][j] = s;
    }
    // cout << "C" << d[0][0].val() << ' ' << d[0][1].val()<<' ' << d[0][2].val()<<' ' << xmx << ' ' << a << ' ' << b << ' ' << m << endl;
    // cout << "C" << d[1][0].val() << ' ' << d[2][0].val()<<' ' <<' ' << xmx << ' ' << a << ' ' << b << ' ' << m << endl;
    to_ij(d);
    // cout << "B" << d[0][0].val() << ' ' << d[0][1].val()<<' ' << d[0][2].val()<<' ' << xmx << ' ' << a << ' ' << b << ' ' << m << endl;
    // cout << "B" << d[0][0].val() << ' ' << xmx << ' ' << a << ' ' << b << ' ' << m << endl;
    return d;
  };
  auto [a1, a2, b1, b2] = extract_line(a, b, m, n);
  // cout << a1 << ' ' << b1 << ' ' << a2 << ' ' << b2 << endl;
  auto res = rec(rec, -a1, m, b1);
  // cout << res[0][0].val() << ' ' << res[0][1].val() << ' ' << res[0][2].val() << endl;
  // cout << res[0][0].val() << ' ' << ans.val() << endl;
  to_floor(res, n);
  // x^p(f+a2x+b2)^q
  // cout << res[0][0].val() << ' ' << ans.val() << endl;
  rep(j1, q + 1) rep(j2, q - j1 + 1) {
    int j3 = q - j1 - j2;
    // x^(p+j2) f^j1
    // cout << j1 << ' ' << j2 << ' ' << j3 << endl;
    ans += res[p+j2][j1] * C(q, j1) * C(q - j1, j2) * mint(a2).pow(j2) * mint(b2).pow(j3);
  }
  return ans;
}

// input
std::default_random_engine gen(std::chrono::system_clock::now().time_since_epoch().count());
using dist_type = std::uniform_int_distribution<>;
using param_type = dist_type::param_type;

int RI(int L, int R) { assert(L < R); return dist_type(L, R - 1)(gen); }


auto floor_div(signed_integral auto x, signed_integral auto y) {
  return x / y - ((x ^ y) < 0 && x % y != 0);
}
template <integral T>
T floor_div(T x, unsigned_integral auto y) {
  return x >= 0 ? T(x / y) : -T(-x / y + (-x % y != 0));
}
auto ceil_div(signed_integral auto x, signed_integral auto y) {
  return x / y + ((x ^ y) >= 0 && x % y != 0);
}
template <integral T>
T ceil_div(T x, unsigned_integral auto y) {
  return x >= 0 ? T(x / y + (x % y != 0)) : -T(-x / y);
}

} int main() {
  ios::sync_with_stdio(false);
  cin.tie(0);
  // rep(_, 1000) {
  //   int p = RI(0, 10), q = RI(0, 10);
  //   int n = RI(0, 200);
  //   int a = RI(-1e9, 1e9), b = RI(-1e9, 1e9), m = RI(1, 1e9);
  //   // tie(p,q,n,m,a,b)=tuple(0,2,2,2,1,1);
  //   mint res = solve(p, q, n + 1, m, a, b);
  //   mint res_naive;
  //   for (int i = 0; i <= n; i++) {
  //     res_naive += mint(i).pow(p) * mint(floor_div((ll)a * i + b, m)).pow(q);
  //   }
  //   if (res != res_naive) {
  //     cout << p << ' ' << q << ' '<< n << ' ' << m << ' ' << a << ' ' << b << endl;
  //     cout << res.val() << endl;
  //     cout << res_naive.val() << endl;
  //     exit(0);
  //   }
  // }
  // cout << "ok" << endl;
  // exit(0);
  int tt;
  cin >> tt;
  while (tt--) {
    int p, q, n, m, a, b;
    cin >> p >> q >> n >> m >> a >> b;
    cout << solve(p, q, n + 1, m, a, b).val() << '\n';
  }
}
0