結果

問題 No.2327 Inversion Sum
コンテスト
ユーザー InTheBloom
提出日時 2026-02-03 13:06:36
言語 D
(dmd 2.111.0)
結果
AC  
実行時間 71 ms / 2,000 ms
コード長 8,349 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 5,882 ms
コンパイル使用メモリ 204,888 KB
実行使用メモリ 7,972 KB
最終ジャッジ日時 2026-02-03 13:06:50
合計ジャッジ時間 8,878 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 30
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import std;

void main () {
    int N, M;
    readln.read(N, M);
    auto P = new int[](M);
    auto K = new int[](M);
    foreach (i; 0 .. M) {
        readln.read(P[i], K[i]);
        P[i]--;
        K[i]--;
    }

    writeln(solve(N, M, P, K));
}

long naive (int N, int M, const int[] P, const int[] K) {
    long ans = 0;
    auto arr = iota(N).array;
    do {
        bool ok = true;
        foreach (i; 0 .. M) {
            foreach (j; 0 .. N) {
                if (arr[j] == P[i] && j != K[i]) {
                    ok = false;
                }
            }
        }
        if (!ok) {
            continue;
        }

        foreach (i; 0 .. N) {
            foreach (j; 0 .. i) {
                if (arr[i] < arr[j]) {
                    ans++;
                }
            }
        }
    } while (nextPermutation(arr));
    return ans;
}

long solve (int N, int M, const int[] P, const int[] K) {
    const long MOD = 998244353;

    // 転倒数は全ての組に対しての和を考えればよいので、2グループに分けて考える。
    // 固定された値集合X、固定されていない値集合Yとする。
    // 1. X-X: 通常の転倒数アルゴリズムで計算可能。未確定値分|Y|!倍されることに注意。
    // 2. Y-Y: 0 .. |Y|のすべての順列に対する転倒数の総和の計算に帰着。
    //         ある値p, qの転倒数寄与を考えてみる。
    //         任意の順列に対してpとqを入れ替えた順列が存在することから、
    //         |Y|! / 2通りでカウントされることがわかる。
    //         p, qは任意に取れることから、|Y| * (|Y| - 1) * |Y|! / 2 / 2通り。
    // 3. X-Y: Xの各元について寄与を考える。以下Xの元xを考える。寄与は2種類で、
    //         1. xの前に大きなYの元が入る
    //         2. xの後に小さなYの元が入る
    //         あるYの元yを一つとってxとyの寄与を考えると、
    //         x < yならタイプ1のみ、y < xならタイプ2のみ考えればよい。
    //         ある(x, y)は他と独立で考えてよいため、空き地の数 * |Y - 1|!通りに増える。

    long ans = 0;
    auto seg = new SegmentTree!(int, (int a, int b) => a + b, () => 0)(N);
    auto fac = new long[](N + 1);
    fac[0] = 1;
    foreach (i; 1 .. N + 1) {
        fac[i] = fac[i - 1] * i % MOD;
    }

    // 1
    auto ord = iota(M).array;
    ord.sort!((a, b) => K[a] < K[b]);
    foreach (i; ord) {
        ans += seg.prod(P[i] + 1, N);
        seg.set(P[i], 1);
    }
    ans %= MOD;
    ans *= fac[N - M];
    ans %= MOD;

    // 2
    ans += 1L * (N - M) * (N - M - 1) % MOD * mod_inv(4, MOD) % MOD * fac[N - M] % MOD;
    ans %= MOD;

    // 3
    foreach (idx, i; ord.enumerate(0)) {
        // x < y
        {
            int k = K[i] - idx;
            int m = N - P[i] - seg.prod(P[i], N);
            ans += 1L * k * m % MOD * fac[N - M - 1] % MOD;
        }

        // y < x
        {
            int k = N - K[i] - 1 - (M - idx - 1);
            int m = P[i] - seg.prod(0, P[i]);
            ans += 1L * k * m % MOD * fac[N - M - 1] % MOD;
        }
        ans %= MOD;
    }

    return ans;
}

void read (T...) (string S, ref T args) {
    import std.conv : to;
    import std.array : split;
    auto buf = S.split;
    foreach (i, ref arg; args) {
        arg = buf[i].to!(typeof(arg));
    }
}

import std.traits : ReturnType, isCallable, Parameters;
import std.meta : AliasSeq;

class SegmentTree (T, alias ope, alias e)
if (
           isCallable!(ope)
        && isCallable!(e)
        && is (ReturnType!(ope) == T)
        && is (ReturnType!(e) == T)
        && is (Parameters!(ope) == AliasSeq!(T, T))
        && is (Parameters!(e) == AliasSeq!())
        )
{
    /* 内部の配列が1-indexedで2冪のセグメントツリー */
    import std.format : format;
    T[] X;
    size_t length;

    /* --- Constructors --- */
    this (size_t length_) {
        adjust_array_length(length_);
        for (size_t i = length; i < 2*length; i++) {
            X[i] = e();
        }
        build();
    }

    import std.range.primitives : isInputRange;
    this (R) (R Range)
    if (isInputRange!(R) && is (ElementType!(R) == T))
    {
        adjust_array_length(walkLength(Range));
        size_t i = length;
        foreach (item; Range) {
            X[i] = item;
            i++;
        }
        while (i < 2*length) { X[i] = e(); i++; }
        build();
    }

    /* --- Functions --- */
    private 
        void adjust_array_length (size_t length_) {
            length = 1;
            while (length <= length_) length *= 2;
            X = new T[](2*length);
        }

        void set_with_no_update (size_t idx, T val)
        in {
            assert(idx < length,
                    format("In function \"set_with_no_update\", idx is out of range. (length = %s idx = %s)", length, idx));
        }
        do {
            X[length+idx] = val;
        }

        void build () {
            for (size_t i = length-1; 1 <= i; i--) {
                X[i] = ope(X[2*i], X[2*i+1]);
            }
        }

    public
        override string toString () {
            string res = "";
            int level = 1;
            while ((2^^(level-1)) < X.length) {
                res ~= format("level: %2s | ", level);
                for (size_t i = (2^^(level-1)); i < (2^^level); i++) {
                    res ~= format("%s%s", X[i], (i == (2^^level)-1 ? "\n" : " "));
                }
                level++;
            }
            return res;
        }

        void set (size_t idx, T val)
        in {
            assert(idx < length,
                    format("In function \"set\", idx is out of range. (length = %s idx = %s)", length, idx));
        }
        do {
            idx += length;
            X[idx] = val;
            while (2 <= idx) {
                idx /= 2;
                X[idx] = ope(X[2*idx], X[2*idx+1]);
            }
        }

        T get (size_t idx)
        in {
            assert(idx < length,
                    format("In function \"get\", idx is out of range. (length = %s idx = %s)", length, idx));
        }
        do {
            idx += length;
            return X[idx];
        }

        T prod (size_t l, size_t r)
        in {
            assert(l < length,
                    format("In function \"prod\", l is out of range. (length = %s l = %s)", length, l));
            assert(r <= length,
                    format("In function \"prod\", r is out of range. (length = %s r = %s)", length, r));
            assert(l <= r,
                    format("In function \"prod\", l < r must be satisfied. (length = %s l = %s, r = %s)", length, l, r));
        }
        do {
            /* Returns all prod O(1) */
            if (l == 0 && r == length) return X[1];
            if (l == r) return e();

            /* Closed interval [l, r] */
            r--;
            l += length, r += length;
            T LeftProd, RightProd;
            LeftProd = RightProd = e();

            while (l <= r) {
                if (l % 2 == 1) {
                    LeftProd = ope(LeftProd, X[l]);
                    l++;
                }
                if (r % 2 == 0) {
                    RightProd = ope(X[r], RightProd);
                    r--;
                }

                l /= 2;
                r /= 2;
            }

            return ope(LeftProd, RightProd);
        }
}

long mod_pow (long a, long x, const long MOD)
in {
    assert(0 <= x, "x must satisfy 0 <= x");
    assert(1 <= MOD, "MOD must satisfy 1 <= MOD");
    assert(MOD <= int.max, "MOD must satisfy MOD*MOD <= long.max");
}
do {
    // normalize
    a %= MOD; a += MOD; a %= MOD;

    long res = 1L;
    long base = a;
    while (0 < x) {
        if (0 < (x&1)) (res *= base) %= MOD;
        (base *= base) %= MOD;
        x >>= 1;
    }

    return res % MOD;
}

// check mod_pow
static assert(__traits(compiles, mod_pow(2, 10, 998244353)));

long mod_inv (const long x, const long MOD)
in {
    import std.format : format;
    assert(1 <= MOD, format("MOD must satisfy 1 <= MOD. Now MOD =  %s.", MOD));
    assert(MOD <= int.max, format("MOD must satisfy MOD*MOD <= long.max. Now MOD = %s.", MOD));
}
do {
    return mod_pow(x, MOD-2, MOD);
}
0