結果

問題 No.3239 Omnibus
ユーザー Nauclhlt🪷
提出日時 2025-08-14 22:27:39
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 485 ms / 10,000 ms
コード長 5,867 bytes
コンパイル時間 2,167 ms
コンパイル使用メモリ 203,036 KB
実行使用メモリ 16,768 KB
最終ジャッジ日時 2025-08-14 23:56:48
合計ジャッジ時間 12,985 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 33
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;

using ll = long long;

template <typename T>
class Set {
    struct Node {
        T value, sum;
        Node* left;
        Node* right;
        int bias, height, size;
        Node(const T& v)
            : value(v), sum(v), left(nullptr), right(nullptr),
              bias(0), height(1), size(1) {}
        int leftHeight() const { return left ? left->height : 0; }
        int rightHeight() const { return right ? right->height : 0; }
        int leftSize() const { return left ? left->size : 0; }
        int rightSize() const { return right ? right->size : 0; }
    };
    Node* root = nullptr;

    int heightOf(Node* n) { return n ? max(n->leftHeight(), n->rightHeight()) + 1 : 0; }
    int sizeOf(Node* n) { return n ? n->leftSize() + n->rightSize() + 1 : 0; }
    T sumOf(Node* n) { return n ? n->sum : T(0); }

    void update(Node* n) {
        if (!n) return;
        n->height = heightOf(n);
        n->size = sizeOf(n);
        n->bias = n->leftHeight() - n->rightHeight();
        n->sum = sumOf(n->left) + sumOf(n->right) + n->value;
    }

    Node* rotateLeft(Node* n) {
        Node* r = n->right;
        n->right = r->left;
        r->left = n;
        update(n);
        update(r);
        return r;
    }
    Node* rotateRight(Node* n) {
        Node* l = n->left;
        n->left = l->right;
        l->right = n;
        update(n);
        update(l);
        return l;
    }
    Node* balance(Node* n) {
        if (!n) return n;
        if (n->bias >= 2) {
            if (n->left && n->left->bias < 0)
                n->left = rotateLeft(n->left);
            return rotateRight(n);
        }
        if (n->bias <= -2) {
            if (n->right && n->right->bias > 0)
                n->right = rotateRight(n->right);
            return rotateLeft(n);
        }
        return n;
    }
    Node* addRec(Node* cur, const T& v) {
        if (!cur) return new Node(v);
        if (v < cur->value) cur->left = addRec(cur->left, v);
        else cur->right = addRec(cur->right, v);
        update(cur);
        return balance(cur);
    }
    Node* getMaxNode(Node* n) { while (n->right) n = n->right; return n; }
    Node* removeRec(Node* cur, const T& v) {
        if (!cur) return nullptr;
        if (v == cur->value) {
            if (cur->left && cur->right) {
                Node* mx = getMaxNode(cur->left);
                cur->value = mx->value;
                cur->left = removeRec(cur->left, mx->value);
            } else {
                Node* nxt = cur->left ? cur->left : cur->right;
                delete cur;
                return nxt;
            }
        } else if (v < cur->value) {
            cur->left = removeRec(cur->left, v);
        } else {
            cur->right = removeRec(cur->right, v);
        }
        update(cur);
        return balance(cur);
    }
    int lowerBoundRec(Node* cur, const T& v, int acc) {
        if (!cur) return acc;
        if (v <= cur->value)
            return lowerBoundRec(cur->left, v, acc);
        else
            return lowerBoundRec(cur->right, v, acc + cur->leftSize() + 1);
    }
    T prefixSumRec(Node* cur, int r) {
        if (!cur) return T(0);
        int leftSz = cur->leftSize();
        if (r <= leftSz)
            return prefixSumRec(cur->left, r);
        else if (r == leftSz + 1)
            return sumOf(cur->left) + cur->value;
        else
            return sumOf(cur->left) + cur->value + prefixSumRec(cur->right, r - leftSz - 1);
    }

public:
    int size() { return sizeOf(root); }
    void add(const T& v) { root = addRec(root, v); }
    void remove(const T& v) { root = removeRec(root, v); }
    int lowerBound(const T& v) { return lowerBoundRec(root, v, 0); }
    T prefixSum(int r) {
        if (!root || r <= 0) return T(0);
        if (r >= size() + 1) return root->sum;
        return prefixSumRec(root, r);
    }
};

// ---------------- main logic ----------------
inline int encode(const string& s, int start) {
    return (s[start] - 'a') * 26 * 26 + (s[start + 1] - 'a') * 26 + (s[start + 2] - 'a');
}
inline int encode(const vector<char>& s, int start) {
    return (s[start] - 'a') * 26 * 26 + (s[start + 1] - 'a') * 26 + (s[start + 2] - 'a');
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int N, Q;
    cin >> N >> Q;
    string S;
    cin >> S;
    vector<char> state(S.begin(), S.end());

    vector<Set<ll>*> indices(26 * 26 * 26, nullptr);

    for (int i = 0; i < N - 2; i++) {
        int code = encode(S, i);
        if (!indices[code]) indices[code] = new Set<ll>();
        indices[code]->add(i + 1);
    }

    for (int qi = 0; qi < Q; qi++) {
        int q;
        cin >> q;
        if (q == 1) {
            int k;
            char x;
            cin >> k >> x;
            k--;
            for (int j = k - 2; j <= k; j++) {
                if (j < 0 || j >= N - 2) continue;
                int c = encode(state, j);
                if (indices[c]) indices[c]->remove(j + 1);
            }
            state[k] = x;
            for (int j = k - 2; j <= k; j++) {
                if (j < 0 || j >= N - 2) continue;
                int c = encode(state, j);
                if (!indices[c]) indices[c] = new Set<long long>();
                indices[c]->add(j + 1);
            }
        } else if (q == 2) {
            int l, r;
            string a;
            cin >> l >> r >> a;
            int code = encode(a, 0);
            if (!indices[code]) {
                cout << 0 << "\n";
            } else {
                int p = indices[code]->lowerBound(r - 1);
                int s = indices[code]->lowerBound(l);
                ll ans = indices[code]->prefixSum(p) - indices[code]->prefixSum(s);
                ans -= 1LL * (p - s) * (l - 1);
                cout << ans << "\n";
            }
        }
    }
    return 0;
}
0