結果

問題 No.1838 Modulo Straight
ユーザー ei1333333
提出日時 2022-02-11 22:05:55
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
MLE  
実行時間 -
コード長 4,199 bytes
コンパイル時間 4,827 ms
コンパイル使用メモリ 254,168 KB
最終ジャッジ日時 2025-01-27 21:36:10
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1 MLE * 2
other AC * 5 TLE * 9 MLE * 24
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
#include <atcoder/all>

using namespace std;

// Node of the binary trie used by trie_idx.
// Each node has one slot per bit value (0 or 1).
struct trie_nod_idx {
  // t[b]: child reached when the current key bit is b; null while absent.
  trie_nod_idx *t[2];
  // c[b]: counter for branch b; trie_idx::insert increments it every time an
  // inserted key leaves this node through branch b.
  int c[2];

  // A fresh node has no children and both counters at zero.
  trie_nod_idx() : t{nullptr, nullptr}, c{0, 0} {}
};

struct trie_idx {
  trie_nod_idx *root;

  trie_idx() {
    root = new trie_nod_idx();
  }

  inline void insert(int &x) {
    bool index;
    trie_nod_idx *curr = root;
    for(int i = 19; i >= 0; i--) {
      index = (((1 << i) & x) != 0);
      curr->c[index]++;
      if(curr->t[index] == NULL) {
        curr->t[index] = new trie_nod_idx();
      }
      curr = curr->t[index];
    }
  }

  inline int greater(int &x) {
    int cnt = 0;
    bool index;
    trie_nod_idx *curr = root;
    for(int i = 19; i >= 0 && curr != NULL; i--) {
      index = (((1 << i) & x) != 0);

      if(!index)
        cnt += curr->c[1];
      curr = curr->t[index];
    }
    return cnt;
  }

  inline int less(int &x) {
    int cnt = 0;
    bool index;
    trie_nod_idx *curr = root;
    for(int i = 19; i >= 0 && curr != NULL; i--) {
      index = (((1 << i) & x) != 0);
      if(index)
        cnt += curr->c[0];
      curr = curr->t[index];
    }
    return cnt;
  }
};

// Node of the outer value trie (struct trie).
struct trie_nod {
  // t[b]: child reached when the current value bit is b; null while absent.
  trie_nod *t[2];
  // c[b]: how many trie::insert_f values left this node through branch b.
  int c[2];
  // Positions appended by trie::insert_f for values whose bit is 1 here.
  std::vector<int> idx;
  // Positions recorded by trie::remove for values whose bit is 1 here.
  trie_idx removed;
  // Positions recorded by trie::insert for values whose bit is 1 here.
  trie_idx added;

  // A fresh node has no children and both counters at zero; idx, removed
  // and added are default-constructed.
  trie_nod() : t{nullptr, nullptr}, c{0, 0} {}
};

// Binary trie keyed on the low 20 bits of a value (bit 19 first). Every stored
// value carries a position `idx`; at each node, the positions of the values
// that take the 1-branch are tracked in three places:
//   - trie_nod::idx      : positions added by insert_f (plain vector),
//   - trie_nod::added    : positions added later by insert,
//   - trie_nod::removed  : positions cancelled by remove.
// The queries combine them as  idx + added - removed.
//
// Preconditions implied by the code:
//   - insert_f must be called with non-decreasing idx, because the queries
//     binary-search trie_nod::idx and therefore need it sorted;
//   - values and positions must fit in 20 bits (only bits 19..0 are read).
//
// Nodes are allocated with `new` and never freed.
struct trie {
  trie_nod *root = new trie_nod();

  // Bulk-load insertion. Stores x at position idx and returns how many values
  // previously stored through insert_f are strictly greater than x.
  inline int insert_f(int x, int idx) {
    int cnt = 0;
    trie_nod *curr = root;
    for(int i = 19; i >= 0; i--) {
      const bool index = (((1 << i) & x) != 0);
      if(!index)
        cnt += curr->c[1];  // everything on the 1-branch here is larger
      curr->c[index]++;
      if(index) {
        curr->idx.push_back(idx);
      }
      if(curr->t[index] == nullptr) {
        curr->t[index] = new trie_nod();
      }
      curr = curr->t[index];
    }
    return cnt;
  }

  // Records value x at position idx in the `added` index tries along its
  // path, creating missing nodes. Does not touch c[] or the idx vectors.
  inline void insert(int x, int idx) {
    trie_nod *curr = root;
    for(int i = 19; i >= 0; i--) {
      const bool index = (((1 << i) & x) != 0);
      if(index) {
        curr->added.insert(idx);
      }
      if(curr->t[index] == nullptr) {
        curr->t[index] = new trie_nod();
      }
      curr = curr->t[index];
    }
  }

  // Records value x at position idx in the `removed` index tries along its
  // path, cancelling an earlier insert_f/insert of the same (x, idx) in the
  // queries below. No nodes are created: if the path for x is missing, the
  // walk stops at the first absent node instead of dereferencing a null
  // pointer.
  inline void remove(int x, int idx) {
    trie_nod *curr = root;
    for(int i = 19; i >= 0 && curr != nullptr; i--) {
      const bool index = (((1 << i) & x) != 0);
      if(index) {
        curr->removed.insert(idx);
      }
      curr = curr->t[index];
    }
  }

  // Counts live entries whose value is strictly greater than x AND whose
  // position is strictly greater than idx.
  inline int greater_a(int x, int idx) const {
    int cnt = 0;
    trie_nod *curr = root;
    for(int i = 19; i >= 0 && curr != nullptr; i--) {
      const bool index = (((1 << i) & x) != 0);
      if(!index) {
        // Entries on the 1-branch are larger than x; keep those after idx.
        cnt += curr->idx.end() - std::upper_bound(curr->idx.begin(), curr->idx.end(), idx);
        cnt += curr->added.greater(idx);
        cnt -= curr->removed.greater(idx);
      }
      curr = curr->t[index];
    }
    return cnt;
  }

  // Counts live entries whose value is strictly greater than x AND whose
  // position is strictly less than idx.
  inline int greater_b(int x, int idx) const {
    int cnt = 0;
    trie_nod *curr = root;
    for(int i = 19; i >= 0 && curr != nullptr; i--) {
      const bool index = (((1 << i) & x) != 0);
      if(!index) {
        // Entries on the 1-branch are larger than x; keep those before idx.
        cnt += std::lower_bound(curr->idx.begin(), curr->idx.end(), idx) - curr->idx.begin();
        cnt += curr->added.less(idx);
        cnt -= curr->removed.less(idx);
      }
      curr = curr->t[index];
    }
    return cnt;
  }

  // Counts live entries with value in [x, y] and position > idx.
  // x - 1 is handed to greater_a, which reads only bits 19..0, so x must be
  // at least 1 for the result to be meaningful.
  inline int between_a(int x, int y, int idx) const {
    return greater_a(x - 1, idx) - greater_a(y, idx);
  }

  // Counts live entries with value in [x, y] and position < idx.
  // Same x >= 1 requirement as between_a.
  inline int between_b(int x, int y, int idx) const {
    return greater_b(x - 1, idx) - greater_b(y, idx);
  }
};

// Reads M, K and M*K values, then prints the minimum over M+1 key assignments
// of a pair count that is maintained incrementally with `trie`.
int main() {
  int M, K;
  std::cin >> M >> K;
  const int total = M * K;
  std::vector<int> A(total);
  for(int i = 0; i < total; ++i) {
    std::cin >> A[i];
  }

  // D[v] lists the positions holding input value v, in increasing order.
  std::vector<std::vector<int>> D(M);
  int64_t cost = 0;
  trie tr;
  for(int pos = 0; pos < total; ++pos) {
    std::vector<int> &bucket = D[A[pos]];
    bucket.emplace_back(pos);
    // The j-th occurrence (0-based) of v gets key v + 2*j*M, then +1.
    A[pos] += 2 * (bucket.size() - 1) * M;
    ++A[pos];
    // insert_f returns how many earlier keys are larger than this one.
    cost += tr.insert_f(A[pos], pos);
  }

  int64_t best = cost;
  // Raise the keys of value 0, then value 1, ... by M, updating the count
  // after each single key change and taking the minimum after each group.
  for(int v = 0; v < M; ++v) {
    for(const int pos : D[v]) {
      const int old_key = A[pos];
      const int new_key = old_key + M;
      // Earlier positions with keys in (old_key, new_key] stop counting.
      cost -= tr.between_b(old_key + 1, new_key, pos);
      // Later positions with keys in [old_key, new_key) start counting.
      cost += tr.between_a(old_key, new_key - 1, pos);
      tr.remove(old_key, pos);
      tr.insert(new_key, pos);
      A[pos] = new_key;
    }
    best = std::min(best, cost);
  }
  std::cout << best << "\n";
}
0