結果
問題 |
No.3239 Omnibus
|
ユーザー |
![]() |
提出日時 | 2025-08-14 22:27:39 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 |
AC
|
実行時間 | 485 ms / 10,000 ms |
コード長 | 5,867 bytes |
コンパイル時間 | 2,167 ms |
コンパイル使用メモリ | 203,036 KB |
実行使用メモリ | 16,768 KB |
最終ジャッジ日時 | 2025-08-14 23:56:48 |
合計ジャッジ時間 | 12,985 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 33 |
ソースコード
#include <algorithm>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

using ll = long long;

/// Size of the alphabet the trigram encoding assumes ('a'..'z').
constexpr int kAlphabet = 26;
/// Number of distinct three-letter codes produced by encode().
constexpr int kTrigramCount = kAlphabet * kAlphabet * kAlphabet;

/// @brief AVL-balanced multiset that additionally tracks subtree sizes and sums.
///
/// Supports insertion, removal of one occurrence, "how many elements are < v",
/// and "sum of the r smallest elements". Every operation walks a single
/// root-to-leaf path of the balanced tree.
///
/// The container owns its nodes: they are released in the destructor. Copying
/// is disabled (a member-wise copy would alias the owned nodes); moving
/// transfers ownership.
///
/// @tparam T value type; must provide `<`, `<=`, `==`, `+` and be constructible
///           from `0`.
template <typename T>
class Set {
    struct Node {
        T value;
        T sum;  // sum of `value` over this node's whole subtree
        Node* left = nullptr;
        Node* right = nullptr;
        int bias = 0;    // left height minus right height
        int height = 1;  // height of this subtree (a leaf has height 1)
        int size = 1;    // number of nodes in this subtree

        explicit Node(const T& v) : value(v), sum(v) {}

        int leftHeight() const { return left ? left->height : 0; }
        int rightHeight() const { return right ? right->height : 0; }
        int leftSize() const { return left ? left->size : 0; }
        int rightSize() const { return right ? right->size : 0; }
    };

    Node* root = nullptr;

    /// Frees every node of the subtree rooted at `n`.
    /// Recursion depth equals the tree height, which the AVL balancing keeps small.
    static void destroy(Node* n) noexcept {
        if (!n) return;
        destroy(n->left);
        destroy(n->right);
        delete n;
    }

    static int heightOf(const Node* n) {
        return n ? std::max(n->leftHeight(), n->rightHeight()) + 1 : 0;
    }

    static int sizeOf(const Node* n) {
        return n ? n->leftSize() + n->rightSize() + 1 : 0;
    }

    static T sumOf(const Node* n) { return n ? n->sum : T(0); }

    /// Recomputes the cached height/size/bias/sum of `n` from its children.
    static void update(Node* n) {
        if (!n) return;
        n->height = heightOf(n);
        n->size = sizeOf(n);
        n->bias = n->leftHeight() - n->rightHeight();
        n->sum = sumOf(n->left) + sumOf(n->right) + n->value;
    }

    static Node* rotateLeft(Node* n) {
        Node* r = n->right;
        n->right = r->left;
        r->left = n;
        update(n);  // `n` is now the child, so refresh it before its new parent
        update(r);
        return r;
    }

    static Node* rotateRight(Node* n) {
        Node* l = n->left;
        n->left = l->right;
        l->right = n;
        update(n);  // `n` is now the child, so refresh it before its new parent
        update(l);
        return l;
    }

    /// Restores the AVL invariant at `n` (whose cached fields must be current)
    /// and returns the new root of that subtree.
    static Node* balance(Node* n) {
        if (!n) return n;
        if (n->bias >= 2) {
            // Left-right case: straighten the left child first.
            if (n->left && n->left->bias < 0) n->left = rotateLeft(n->left);
            return rotateRight(n);
        }
        if (n->bias <= -2) {
            // Right-left case: straighten the right child first.
            if (n->right && n->right->bias > 0) n->right = rotateRight(n->right);
            return rotateLeft(n);
        }
        return n;
    }

    /// Inserts `v` below `cur`; values equal to an existing one go to the right.
    static Node* addRec(Node* cur, const T& v) {
        if (!cur) return new Node(v);
        if (v < cur->value) {
            cur->left = addRec(cur->left, v);
        } else {
            cur->right = addRec(cur->right, v);
        }
        update(cur);
        return balance(cur);
    }

    static Node* getMaxNode(Node* n) {
        while (n->right) n = n->right;
        return n;
    }

    /// Removes one occurrence of `v` below `cur`; a no-op if `v` is absent.
    static Node* removeRec(Node* cur, const T& v) {
        if (!cur) return nullptr;
        if (v == cur->value) {
            if (cur->left && cur->right) {
                // Two children: overwrite with the in-order predecessor and
                // delete that predecessor from the left subtree instead.
                // Copy the value first so the recursion never holds a
                // reference into the node it is about to delete.
                const T predecessor = getMaxNode(cur->left)->value;
                cur->value = predecessor;
                cur->left = removeRec(cur->left, predecessor);
            } else {
                Node* replacement = cur->left ? cur->left : cur->right;
                delete cur;
                return replacement;
            }
        } else if (v < cur->value) {
            cur->left = removeRec(cur->left, v);
        } else {
            cur->right = removeRec(cur->right, v);
        }
        update(cur);
        return balance(cur);
    }

public:
    Set() = default;
    ~Set() { destroy(root); }

    Set(const Set&) = delete;
    Set& operator=(const Set&) = delete;

    Set(Set&& other) noexcept : root(std::exchange(other.root, nullptr)) {}
    Set& operator=(Set&& other) noexcept {
        if (this != &other) {
            destroy(root);
            root = std::exchange(other.root, nullptr);
        }
        return *this;
    }

    /// @return number of stored elements.
    int size() const { return root ? root->size : 0; }

    /// Inserts `v` (duplicates are kept).
    void add(const T& v) { root = addRec(root, v); }

    /// Removes one occurrence of `v`; does nothing if `v` is not stored.
    void remove(const T& v) { root = removeRec(root, v); }

    /// @return number of stored elements strictly less than `v`.
    int lowerBound(const T& v) const {
        int below = 0;
        const Node* cur = root;
        while (cur) {
            if (v <= cur->value) {
                cur = cur->left;
            } else {
                // `cur` and its whole left subtree are smaller than `v`.
                below += cur->leftSize() + 1;
                cur = cur->right;
            }
        }
        return below;
    }

    /// @return sum of the `r` smallest elements; `0` for `r <= 0` or an empty
    ///         set, and the sum of everything when `r` exceeds size().
    T prefixSum(int r) const {
        if (!root || r <= 0) return T(0);
        if (r > root->size) return root->sum;
        T total = T(0);
        const Node* cur = root;
        while (cur) {
            const int leftSz = cur->leftSize();
            if (r <= leftSz) {
                cur = cur->left;
                continue;
            }
            // The entire left subtree and `cur` itself are inside the prefix.
            total = total + sumOf(cur->left) + cur->value;
            if (r == leftSz + 1) break;
            r -= leftSz + 1;
            cur = cur->right;
        }
        return total;
    }
};

// ---------------- main logic ----------------

/// Shared implementation of encode(): base-26 value of s[start..start+2].
template <typename Seq>
inline int encodeTrigram(const Seq& s, int start) {
    return ((s[start] - 'a') * kAlphabet + (s[start + 1] - 'a')) * kAlphabet +
           (s[start + 2] - 'a');
}

/// @brief Maps the three characters s[start], s[start+1], s[start+2] to an
///        integer, treating each as a base-26 digit relative to 'a'.
/// @pre `start + 2` is a valid index of `s`; the result is only in
///      [0, kTrigramCount) when all three characters are in 'a'..'z'.
inline int encode(const std::string& s, int start) {
    return encodeTrigram(s, start);
}

/// Overload of encode() for a character vector; kept so existing callers of
/// either signature continue to work.
inline int encode(const std::vector<char>& s, int start) {
    return encodeTrigram(s, start);
}

/// @brief Reads N, Q and a string S, then processes Q queries from stdin.
///
/// For every three-letter code the program keeps the set of 1-based start
/// positions at which that trigram currently occurs in S.
///  - `1 k x`   : sets the k-th (1-based) character of S to x and re-indexes
///                the up-to-three trigrams that cover that position.
///  - `2 l r a` : prints the sum of (p - l + 1) over all start positions p
///                with l <= p <= r - 2 at which trigram `a` occurs.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // Initialised so a failed extraction cannot leave them indeterminate.
    int N = 0;
    int Q = 0;
    std::string S;
    std::cin >> N >> Q >> S;

    // One (initially empty) position set per trigram code, owned by value.
    std::vector<Set<ll>> indices(kTrigramCount);
    for (int i = 0; i + 2 < N; ++i) {
        indices[encode(S, i)].add(i + 1);
    }

    for (int qi = 0; qi < Q; ++qi) {
        int q = 0;
        std::cin >> q;
        if (q == 1) {
            int k = 0;
            char x = '\0';
            std::cin >> k >> x;
            --k;

            // Trigram start positions (0-based) whose window contains index k.
            const int first = std::max(0, k - 2);
            const int last = std::min(k, N - 3);

            for (int j = first; j <= last; ++j) {
                indices[encode(S, j)].remove(j + 1);
            }
            S[k] = x;
            for (int j = first; j <= last; ++j) {
                indices[encode(S, j)].add(j + 1);
            }
        } else if (q == 2) {
            int l = 0;
            int r = 0;
            std::string a;
            std::cin >> l >> r >> a;

            const Set<ll>& positions = indices[encode(a, 0)];
            const int upto = positions.lowerBound(r - 1);  // starts p <= r - 2
            const int from = positions.lowerBound(l);      // starts p <  l
            ll ans = positions.prefixSum(upto) - positions.prefixSum(from);
            // Each counted start p contributes p - (l - 1).
            ans -= 1LL * (upto - from) * (l - 1);
            std::cout << ans << "\n";
        }
    }
    return 0;
}