結果

問題 No.2362 Inversion Number of Mod of Linear
ユーザー 遭難者遭難者
提出日時 2023-05-07 12:29:56
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 324 ms / 2,000 ms
コード長 2,375 bytes
コンパイル時間 4,615 ms
コンパイル使用メモリ 267,148 KB
実行使用メモリ 5,376 KB
最終ジャッジ日時 2024-06-23 00:13:11
合計ジャッジ時間 5,195 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 2 ms
5,248 KB
testcase_02 AC 1 ms
5,248 KB
testcase_03 AC 138 ms
5,248 KB
testcase_04 AC 60 ms
5,376 KB
testcase_05 AC 28 ms
5,376 KB
testcase_06 AC 324 ms
5,376 KB
testcase_07 AC 113 ms
5,376 KB
testcase_08 AC 104 ms
5,376 KB
testcase_09 AC 103 ms
5,376 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#include <atcoder/all>
using namespace std;
using namespace atcoder;
#define rep(i, n) for (int i = 0; i < n; i++)
#define ALL(a) a.begin(), a.end()
#define ll long long
#define pii pair<int, int>
#define pil pair<int, ll>
#define pli pair<ll, int>
#define vc vector
using mint = modint;
ll N[50], M[50], A[50], B[50];
mint f[50], g[50], h[50];
int c = 0;
mint inv_two, inv_three;
void q(ll n, ll m, ll a, ll b)
{
    if (m == 0)
    {
        f[c] = g[c] = h[c] = 0, c++;
        return;
    }
    const int cnt = c;
    const ll a1 = a / m, a2 = a % m;
    const ll b1 = b / m, b2 = b % m;
    N[c] = n, M[c] = m, A[c] = a2, B[c] = b2, c++;
    const ll y = (a2 * n + b2) / m;
    q(y, a2, m, m + a2 - b2 - 1);
    const mint nn = inv_two * n * (n - 1);
    f[cnt] = mint(n) * mint(y) - f[cnt + 1];
    g[cnt] = mint(n) * mint(y) * y - f[cnt + 1] - 2 * h[cnt + 1];
    h[cnt] = mint(y) * nn + (f[cnt + 1] - g[cnt + 1]) * inv_two;
    g[cnt] += mint(n) * mint(2 * n - 1) * mint(n - 1) * a1 * a1 * inv_two * inv_three;
    g[cnt] += mint(2) * nn * a1 * b1;
    g[cnt] += mint(b1) * mint(b1) * n;
    g[cnt] += mint(2) * a1 * h[cnt];
    g[cnt] += mint(2) * b1 * f[cnt];
    f[cnt] += a1 * nn + b1 * n;
    h[cnt] += nn * (a1 * mint(2 * n - 1) + 3 * b1) * inv_three;
    return;
}
void solve()
{
    ll n, m, a, b;
    cin >> n >> m >> a >> b;
    {
        ll gccd = gcd(m, a);
        m /= gccd, a /= gccd, b /= gccd;
    }
    vector<ll> r;
    vector<ll> mmod = {998244353, 1000000007};
    for (int mmmod : mmod)
    {
        mint::set_mod(mmmod);
        inv_two = mint(2).pow(mmmod - 2);
        inv_three = mint(3).pow(mmmod - 2);
        mint ans = 0;
        c = 0, q(n, m, a, b);
        ans -= (n - 1) * f[0];
        ans += 2 * h[0];
        const long t = a * (n - 1) / m;
        const long y = m - a * (n - 1) + t * m;
        c = 0, q(n, m, a, y);
        ans += f[0];
        ans += h[0];
        ans -= n * mint(n + 1) / 2 * t;
        r.push_back(ans.val());
    }
    ll ans = crt(r, mmod).first;
    const ll k1 = n / m, k2 = n % m;
    ans -= (k1 + 1) * (k1 + 2) / 2 * k2;
    ans -= k1 * (k1 + 1) / 2 * (m - k2);
    cout << ans << '\n';
    return;
}
int main()
{
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    cout << fixed << setprecision(13);
    int t = 1;
    cin >> t;
    rep(i, t) solve();
    return 0;
}
0