結果
問題 | No.3239 Omnibus |
ユーザー |
![]() |
提出日時 | 2025-08-14 22:36:17 |
言語 | PyPy3 (7.3.15) |
結果 | TLE |
実行時間 | - |
コード長 | 6,477 bytes |
コンパイル時間 | 366 ms |
コンパイル使用メモリ | 82,888 KB |
実行使用メモリ | 192,760 KB |
最終ジャッジ日時 | 2025-08-14 22:36:26 |
合計ジャッジ時間 | 8,857 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 4 TLE * 1 -- * 27 |
ソースコード
class Set:
    """Order-statistic multiset backed by an AVL tree.

    Supports insertion, deletion, rank queries (``lower_bound``) and the sum
    of the ``r`` smallest elements (``prefix_sum``); every node caches its
    subtree height, size and value sum so each operation walks one root path.

    NOTE: ``main`` no longer uses this class -- the recursive, object-per-node
    tree was too slow under PyPy (TLE). It is kept unchanged so any code that
    imports ``Set`` keeps working.
    """

    class Node:
        __slots__ = ['value', 'sum', 'left', 'right', 'bias', 'height', 'size']

        def __init__(self, value):
            self.value = value
            self.sum = value  # sum of all values in this subtree
            self.left = None
            self.right = None
            self.bias = 0  # left height minus right height
            self.height = 1
            self.size = 1  # number of nodes in this subtree

        def left_height(self):
            return self.left.height if self.left else 0

        def right_height(self):
            return self.right.height if self.right else 0

        def left_size(self):
            return self.left.size if self.left else 0

        def right_size(self):
            return self.right.size if self.right else 0

    def __init__(self):
        self.root = None

    def height_of(self, node):
        if not node:
            return 0
        return max(node.left_height(), node.right_height()) + 1

    def size_of(self, node):
        if not node:
            return 0
        return node.left_size() + node.right_size() + 1

    def sum_of(self, node):
        if not node:
            return 0
        return node.sum

    def update(self, node):
        """Recompute the cached height/size/bias/sum of ``node`` from its children."""
        if not node:
            return
        node.height = self.height_of(node)
        node.size = self.size_of(node)
        node.bias = node.left_height() - node.right_height()
        node.sum = self.sum_of(node.left) + self.sum_of(node.right) + node.value

    def rotate_left(self, node):
        r = node.right
        node.right = r.left
        r.left = node
        # Child first: the new parent's stats depend on it.
        self.update(node)
        self.update(r)
        return r

    def rotate_right(self, node):
        l = node.left
        node.left = l.right
        l.right = node
        self.update(node)
        self.update(l)
        return l

    def balance(self, node):
        """Restore the AVL invariant at ``node`` and return the new subtree root."""
        if not node:
            return node
        if node.bias >= 2:
            if node.left and node.left.bias < 0:
                # Left-right case: reduce to left-left first.
                node.left = self.rotate_left(node.left)
            return self.rotate_right(node)
        if node.bias <= -2:
            if node.right and node.right.bias > 0:
                # Right-left case: reduce to right-right first.
                node.right = self.rotate_right(node.right)
            return self.rotate_left(node)
        return node

    def add_rec(self, cur, value):
        if not cur:
            return self.Node(value)
        if value < cur.value:
            cur.left = self.add_rec(cur.left, value)
        else:
            cur.right = self.add_rec(cur.right, value)
        self.update(cur)
        return self.balance(cur)

    def get_max_node(self, node):
        while node.right:
            node = node.right
        return node

    def remove_rec(self, cur, value):
        """Remove one node holding ``value`` (no-op if absent); return new subtree root."""
        if not cur:
            return None
        if value == cur.value:
            if cur.left and cur.right:
                # Two children: replace with the in-order predecessor.
                mx = self.get_max_node(cur.left)
                cur.value = mx.value
                cur.left = self.remove_rec(cur.left, mx.value)
            else:
                # Zero or one child: splice the node out; the child's stats
                # are already correct.
                nxt = cur.left if cur.left else cur.right
                return nxt
        elif value < cur.value:
            cur.left = self.remove_rec(cur.left, value)
        else:
            cur.right = self.remove_rec(cur.right, value)
        self.update(cur)
        return self.balance(cur)

    def lower_bound_rec(self, cur, value, acc):
        if not cur:
            return acc
        if value <= cur.value:
            return self.lower_bound_rec(cur.left, value, acc)
        else:
            return self.lower_bound_rec(cur.right, value, acc + cur.left_size() + 1)

    def prefix_sum_rec(self, cur, r):
        if not cur:
            return 0
        left_sz = cur.left_size()
        if r <= left_sz:
            return self.prefix_sum_rec(cur.left, r)
        elif r == left_sz + 1:
            return self.sum_of(cur.left) + cur.value
        else:
            return (self.sum_of(cur.left) + cur.value
                    + self.prefix_sum_rec(cur.right, r - left_sz - 1))

    def size(self):
        return self.size_of(self.root)

    def add(self, value):
        self.root = self.add_rec(self.root, value)

    def remove(self, value):
        self.root = self.remove_rec(self.root, value)

    def lower_bound(self, value):
        """Return the number of stored elements strictly less than ``value``."""
        return self.lower_bound_rec(self.root, value, 0)

    def prefix_sum(self, r):
        """Return the sum of the ``r`` smallest stored elements."""
        if not self.root or r <= 0:
            return 0
        if r >= self.size() + 1:
            return self.root.sum
        return self.prefix_sum_rec(self.root, r)


def encode(s, start):
    """Return the base-26 code of the trigram ``s[start:start + 3]``.

    For lowercase ASCII letters the result lies in ``0 .. 26**3 - 1``.
    """
    return ((ord(s[start]) - ord('a')) * 26 * 26
            + (ord(s[start + 1]) - ord('a')) * 26
            + (ord(s[start + 2]) - ord('a')))


def main():
    """Answer trigram-occurrence queries read from stdin; write answers to stdout.

    Input: ``N Q``, the string ``S`` (length N), then Q queries:

    * ``1 k x`` -- set the k-th (1-based) character of S to ``x``.
    * ``2 l r a`` -- over every 1-based start ``i`` with ``l <= i <= r - 2``
      and ``S[i..i+2] == a``, print the sum of ``i - l + 1``.

    Solved offline: every (trigram, start position) pair that can ever exist
    is collected in a first pass and encoded as the integer key
    ``code * (N + 1) + pos``. Sorting the keys makes each trigram's positions a
    contiguous range, so ONE Fenwick tree over the key array (count and
    position-sum) answers a query with two bisects plus two prefix walks.
    """
    import sys
    from bisect import bisect_left, bisect_right

    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    q = int(data[1])
    base = [ch - 97 for ch in data[2]]  # bytes -> letter indices 0..25
    width = n + 1  # positions are in [1, n-2], so keys of different codes never collide
    last = n - 3  # last valid 0-based window start (negative when n < 3)

    def window_code(st, j):
        return st[j] * 676 + st[j + 1] * 26 + st[j + 2]

    # Parse every query once; both passes below reuse the parsed form.
    queries = []
    ptr = 3
    for _ in range(q):
        if int(data[ptr]) == 1:
            queries.append((1, int(data[ptr + 1]) - 1, data[ptr + 2][0] - 97))
            ptr += 3
        else:
            a = data[ptr + 3]
            code = (a[0] - 97) * 676 + (a[1] - 97) * 26 + (a[2] - 97)
            queries.append((2, int(data[ptr + 1]), int(data[ptr + 2]), code))
            ptr += 4

    # Pass 1: replay updates to learn every key that can ever be present.
    state = base[:]
    keys = [window_code(state, i) * width + i + 1 for i in range(n - 2)]
    for qu in queries:
        if qu[0] == 1:
            k = qu[1]
            state[k] = qu[2]
            for j in range(max(k - 2, 0), min(k, last) + 1):
                keys.append(window_code(state, j) * width + j + 1)
    keys = sorted(set(keys))
    m = len(keys)

    # Fenwick tree (1-based) over key indices: occurrence count and position sum.
    cnt = [0] * (m + 1)
    tot = [0] * (m + 1)
    for i in range(n - 2):
        idx = bisect_left(keys, window_code(base, i) * width + i + 1) + 1
        cnt[idx] = 1
        tot[idx] = i + 1
    for i in range(1, m + 1):  # O(m) bottom-up build
        parent = i + (i & -i)
        if parent <= m:
            cnt[parent] += cnt[i]
            tot[parent] += tot[i]

    def fw_add(i, dc, ds):
        """Add (dc, ds) at 0-based key index ``i``."""
        i += 1
        while i <= m:
            cnt[i] += dc
            tot[i] += ds
            i += i & -i

    def fw_prefix(i):
        """Return (count, position sum) over ``keys[:i]``."""
        c = s = 0
        while i > 0:
            c += cnt[i]
            s += tot[i]
            i -= i & -i
        return c, s

    # Pass 2: answer queries online against the Fenwick tree.
    state = base
    out = []
    for qu in queries:
        if qu[0] == 1:
            _, k, x = qu
            if state[k] == x:
                continue  # no window changes its trigram
            lo = max(k - 2, 0)
            hi = min(k, last)
            for j in range(lo, hi + 1):
                fw_add(bisect_left(keys, window_code(state, j) * width + j + 1), -1, -(j + 1))
            state[k] = x
            for j in range(lo, hi + 1):
                fw_add(bisect_left(keys, window_code(state, j) * width + j + 1), 1, j + 1)
        else:
            _, l, r, code = qu
            lo = max(l, 1)
            hi = min(r - 2, n - 2)  # last start whose trigram fits inside [l, r]
            if lo > hi:
                out.append(0)
                continue
            key0 = code * width
            c1, s1 = fw_prefix(bisect_right(keys, key0 + hi))
            c0, s0 = fw_prefix(bisect_left(keys, key0 + lo))
            # sum over matches of (i - l + 1) == sum(i) - count * (l - 1)
            out.append((s1 - s0) - (c1 - c0) * (l - 1))
    sys.stdout.write("".join(f"{v}\n" for v in out))


if __name__ == "__main__":
    main()