結果
問題 |
No.3239 Omnibus
|
ユーザー |
![]() |
提出日時 | 2025-08-14 22:32:59 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 5,244 bytes |
コンパイル時間 | 434 ms |
コンパイル使用メモリ | 81,796 KB |
実行使用メモリ | 191,092 KB |
最終ジャッジ日時 | 2025-08-14 22:33:08 |
合計ジャッジ時間 | 8,581 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 3 WA * 1 TLE * 1 -- * 27 |
ソースコード
class AVLSet: class Node: __slots__ = ['val', 'sum', 'left', 'right', 'bias', 'height', 'size'] def __init__(self, val): self.val = val self.sum = val self.left = None self.right = None self.bias = 0 self.height = 1 self.size = 1 def __init__(self): self.root = None # ===== 基本操作 ===== def add(self, val): self.root = self._add(self.root, val) def remove(self, val): self.root = self._remove(self.root, val) def lower_bound(self, val): node = self.root res = 0 idx = 0 while node: if val <= node.val: node = node.left else: left_size = node.left.size if node.left else 0 idx += left_size + 1 node = node.right return idx def prefix_sum(self, r): return self._prefix_sum(self.root, r) # ===== 内部関数 ===== def _update(self, node): if node is None: return lh = node.left.height if node.left else 0 rh = node.right.height if node.right else 0 node.height = max(lh, rh) + 1 node.bias = lh - rh ls = node.left.size if node.left else 0 rs = node.right.size if node.right else 0 node.size = ls + rs + 1 lsum = node.left.sum if node.left else 0 rsum = node.right.sum if node.right else 0 node.sum = lsum + rsum + node.val def _rotate_left(self, node): r = node.right node.right = r.left r.left = node self._update(node) self._update(r) return r def _rotate_right(self, node): l = node.left node.left = l.right l.right = node self._update(node) self._update(l) return l def _balance(self, node): if node.bias < -1: if node.right and node.right.bias > 0: node.right = self._rotate_right(node.right) return self._rotate_left(node) elif node.bias > 1: if node.left and node.left.bias < 0: node.left = self._rotate_left(node.left) return self._rotate_right(node) return node def _add(self, node, val): if node is None: return AVLSet.Node(val) if val < node.val: node.left = self._add(node.left, val) else: node.right = self._add(node.right, val) self._update(node) return self._balance(node) def _remove(self, node, val): if node is None: return None if val < node.val: node.left = self._remove(node.left, val) elif val > node.val: node.right = self._remove(node.right, val) else: if node.left and node.right: pred = node.left while pred.right: pred = pred.right node.val = pred.val node.left = self._remove(node.left, pred.val) elif node.left: return node.left elif node.right: return node.right else: return None self._update(node) return self._balance(node) def _prefix_sum(self, node, r): if node is None or r <= 0: return 0 left_size = node.left.size if node.left else 0 if r <= left_size: return self._prefix_sum(node.left, r) s = node.left.sum if node.left else 0 if r == left_size + 1: return s + node.val return s + node.val + self._prefix_sum(node.right, r - left_size - 1) # ===== solve 関数 ===== def encode(s, i): return (ord(s[i]) - ord('a')) * 26 * 26 + (ord(s[i+1]) - ord('a')) * 26 + (ord(s[i+2]) - ord('a')) def solve(): import sys input = sys.stdin.readline N, Q = map(int, input().split()) S = input().strip() state = list(S) indices = [None] * (26*26*26) for i in range(N-2): code = encode(S, i) if indices[code] is None: indices[code] = AVLSet() indices[code].add(i+1) for _ in range(Q): tmp = input().split() q = int(tmp[0]) if q == 1: k = int(tmp[1]) - 1 x = tmp[2] for j in range(k-2, k+1): if 0 <= j < N-2: c = encode(state, j) indices[c].remove(j+1) state[k] = x for j in range(k-2, k+1): if 0 <= j < N-2: c = encode(state, j) if indices[c] is None: indices[c] = AVLSet() indices[c].add(j+1) else: l = int(tmp[1]) r = int(tmp[2]) a = tmp[3] code = encode(a, 0) if indices[code] is None: print(0) else: p = indices[code].lower_bound(r) s = indices[code].lower_bound(l) ans = indices[code].prefix_sum(p) - indices[code].prefix_sum(s) ans -= (p - s) * (l - 1) print(ans) solve();