結果

問題 No.2362 Inversion Number of Mod of Linear
ユーザー 遭難者遭難者
提出日時 2023-05-06 21:47:05
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 640 ms / 2,000 ms
コード長 2,176 bytes
コンパイル時間 4,305 ms
コンパイル使用メモリ 265,116 KB
実行使用メモリ 4,380 KB
最終ジャッジ日時 2023-09-05 03:29:22
合計ジャッジ時間 6,679 ms
ジャッジサーバーID
(参考情報)
judge11 / judge14
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
4,376 KB
testcase_01 AC 1 ms
4,380 KB
testcase_02 AC 2 ms
4,376 KB
testcase_03 AC 273 ms
4,380 KB
testcase_04 AC 101 ms
4,380 KB
testcase_05 AC 42 ms
4,376 KB
testcase_06 AC 640 ms
4,376 KB
testcase_07 AC 214 ms
4,376 KB
testcase_08 AC 191 ms
4,380 KB
testcase_09 AC 195 ms
4,376 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#include <atcoder/all>
using namespace std;
using namespace atcoder;
#define rep(i, n) for (int i = 0; i < n; i++)
#define ALL(a) a.begin(), a.end()
#define ll long long
#define pii pair<int, int>
#define pil pair<int, ll>
#define pli pair<ll, int>
#define vc vector
using mint = modint;
ll N[50], M[50], A[50], B[50];
mint f[50], g[50], h[50];
int c = 0;
void q(ll n, ll m, ll a, ll b)
{
    if (m == 0)
    {
        f[c] = g[c] = h[c] = 0, c++;
        return;
    }
    const int cnt = c;
    const ll a1 = a / m, a2 = a % m;
    const ll b1 = b / m, b2 = b % m;
    N[c] = n, M[c] = m, A[c] = a2, B[c] = b2, c++;
    const ll y = (a2 * n + b2) / m;
    q(y, a2, m, m + a2 - b2 - 1);
    const mint nn = n * (n - 1) / 2;
    f[cnt] = n * mint(y) - f[cnt + 1];
    g[cnt] = n * mint(y) * y - f[cnt + 1] - 2 * h[cnt + 1];
    h[cnt] = nn * mint(y) + (f[cnt + 1] - g[cnt + 1]) / 2;
    g[cnt] += n * mint(2 * n - 1) * (n - 1) * a1 * a1 / 6;
    g[cnt] += 2 * nn * a1 * b1;
    g[cnt] += b1 * mint(b1) * n;
    g[cnt] += 2 * a1 * h[cnt];
    g[cnt] += 2 * b1 * f[cnt];
    f[cnt] += a1 * nn + b1 * n;
    h[cnt] += nn * (a1 * (2 * n - 1) + 3 * b1) / 3;
    return;
}
void solve()
{
    ll n, m, a, b;
    cin >> n >> m >> a >> b;
    {
        ll gccd = gcd(m, a);
        m /= gccd, a /= gccd, b /= gccd;
    }
    vector<ll> r;
    vector<ll> mmod = {998244353, 1000000007};
    for (int mmmod : mmod)
    {
        mint::set_mod(mmmod);
        mint ans = 0;
        c = 0, q(n, m, a, b);
        ans -= (n - 1) * f[0];
        ans += 2 * h[0];
        const long t = a * (n - 1) / m;
        const long y = m - a * (n - 1) + t * m;
        c = 0, q(n, m, a, y);
        ans += f[0];
        ans += h[0];
        ans -= n * mint(n + 1) / 2 * t;
        r.push_back(ans.val());
    }
    ll ans = crt(r, mmod).first;
    const ll k1 = n / m, k2 = n % m;
    ans -= (k1 + 1) * (k1 + 2) / 2 * k2;
    ans -= k1 * (k1 + 1) / 2 * (m - k2);
    cout << ans << '\n';
    return;
}
int main()
{
    cin.tie(nullptr);
    ios::sync_with_stdio(false);
    cout << fixed << setprecision(13);
    int t = 1;
    cin >> t;
    rep(i, t) solve();
    return 0;
}
0