結果
問題 | No.2362 Inversion Number of Mod of Linear |
ユーザー | 遭難者 |
提出日時 | 2023-05-07 12:29:56 |
言語 | C++17 (gcc 12.3.0 + boost 1.83.0) |
結果 |
AC
|
実行時間 | 324 ms / 2,000 ms |
コード長 | 2,375 bytes |
コンパイル時間 | 4,615 ms |
コンパイル使用メモリ | 267,148 KB |
実行使用メモリ | 5,376 KB |
最終ジャッジ日時 | 2024-06-23 00:13:11 |
合計ジャッジ時間 | 5,195 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 2 ms
5,248 KB |
testcase_01 | AC | 2 ms
5,248 KB |
testcase_02 | AC | 1 ms
5,248 KB |
testcase_03 | AC | 138 ms
5,248 KB |
testcase_04 | AC | 60 ms
5,376 KB |
testcase_05 | AC | 28 ms
5,376 KB |
testcase_06 | AC | 324 ms
5,376 KB |
testcase_07 | AC | 113 ms
5,376 KB |
testcase_08 | AC | 104 ms
5,376 KB |
testcase_09 | AC | 103 ms
5,376 KB |
ソースコード
#include <bits/stdc++.h> #include <atcoder/all> using namespace std; using namespace atcoder; #define rep(i, n) for (int i = 0; i < n; i++) #define ALL(a) a.begin(), a.end() #define ll long long #define pii pair<int, int> #define pil pair<int, ll> #define pli pair<ll, int> #define vc vector using mint = modint; ll N[50], M[50], A[50], B[50]; mint f[50], g[50], h[50]; int c = 0; mint inv_two, inv_three; void q(ll n, ll m, ll a, ll b) { if (m == 0) { f[c] = g[c] = h[c] = 0, c++; return; } const int cnt = c; const ll a1 = a / m, a2 = a % m; const ll b1 = b / m, b2 = b % m; N[c] = n, M[c] = m, A[c] = a2, B[c] = b2, c++; const ll y = (a2 * n + b2) / m; q(y, a2, m, m + a2 - b2 - 1); const mint nn = inv_two * n * (n - 1); f[cnt] = mint(n) * mint(y) - f[cnt + 1]; g[cnt] = mint(n) * mint(y) * y - f[cnt + 1] - 2 * h[cnt + 1]; h[cnt] = mint(y) * nn + (f[cnt + 1] - g[cnt + 1]) * inv_two; g[cnt] += mint(n) * mint(2 * n - 1) * mint(n - 1) * a1 * a1 * inv_two * inv_three; g[cnt] += mint(2) * nn * a1 * b1; g[cnt] += mint(b1) * mint(b1) * n; g[cnt] += mint(2) * a1 * h[cnt]; g[cnt] += mint(2) * b1 * f[cnt]; f[cnt] += a1 * nn + b1 * n; h[cnt] += nn * (a1 * mint(2 * n - 1) + 3 * b1) * inv_three; return; } void solve() { ll n, m, a, b; cin >> n >> m >> a >> b; { ll gccd = gcd(m, a); m /= gccd, a /= gccd, b /= gccd; } vector<ll> r; vector<ll> mmod = {998244353, 1000000007}; for (int mmmod : mmod) { mint::set_mod(mmmod); inv_two = mint(2).pow(mmmod - 2); inv_three = mint(3).pow(mmmod - 2); mint ans = 0; c = 0, q(n, m, a, b); ans -= (n - 1) * f[0]; ans += 2 * h[0]; const long t = a * (n - 1) / m; const long y = m - a * (n - 1) + t * m; c = 0, q(n, m, a, y); ans += f[0]; ans += h[0]; ans -= n * mint(n + 1) / 2 * t; r.push_back(ans.val()); } ll ans = crt(r, mmod).first; const ll k1 = n / m, k2 = n % m; ans -= (k1 + 1) * (k1 + 2) / 2 * k2; ans -= k1 * (k1 + 1) / 2 * (m - k2); cout << ans << '\n'; return; } int main() { cin.tie(nullptr); ios::sync_with_stdio(false); cout << fixed << setprecision(13); int t = 1; cin >> t; rep(i, t) solve(); return 0; }