結果
| 問題 | No.3239 Omnibus |
| コンテスト | |
| ユーザー | Nauclhlt🪷 |
| 提出日時 | 2025-08-14 22:36:17 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 6,477 bytes |
| 記録 | |
| コンパイル時間 | 366 ms |
| コンパイル使用メモリ | 82,888 KB |
| 実行使用メモリ | 192,760 KB |
| 最終ジャッジ日時 | 2025-08-14 22:36:26 |
| 合計ジャッジ時間 | 8,857 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 4 TLE * 1 -- * 27 |
ソースコード
class Set:
    """Height-balanced (AVL-style) binary search tree with subtree aggregates.

    Every node caches the height, the element count and the sum of the values
    of its subtree, so the container can answer two rank-style questions by
    walking a single root-to-leaf path:

    * ``lower_bound(v)`` -- how many stored values are strictly less than ``v``;
    * ``prefix_sum(r)``  -- the sum of the ``r`` smallest stored values.

    Equal values are inserted to the right of an existing one, so the
    structure behaves as a multiset.
    """

    class Node:
        """One tree node together with the cached aggregates of its subtree."""

        __slots__ = ['value', 'sum', 'left', 'right', 'bias', 'height', 'size']

        def __init__(self, value):
            self.value = value
            self.sum = value      # sum of all values in this subtree
            self.left = None
            self.right = None
            self.bias = 0         # left height minus right height
            self.height = 1       # a leaf has height 1
            self.size = 1         # number of nodes in this subtree

        def left_height(self):
            """Return the cached height of the left subtree (0 if absent)."""
            return self.left.height if self.left is not None else 0

        def right_height(self):
            """Return the cached height of the right subtree (0 if absent)."""
            return self.right.height if self.right is not None else 0

        def left_size(self):
            """Return the cached size of the left subtree (0 if absent)."""
            return self.left.size if self.left is not None else 0

        def right_size(self):
            """Return the cached size of the right subtree (0 if absent)."""
            return self.right.size if self.right is not None else 0

    def __init__(self):
        self.root = None

    def height_of(self, node):
        """Recompute the height of ``node`` from its children (0 for ``None``)."""
        if node is None:
            return 0
        return max(node.left_height(), node.right_height()) + 1

    def size_of(self, node):
        """Recompute the size of ``node`` from its children (0 for ``None``)."""
        if node is None:
            return 0
        return node.left_size() + node.right_size() + 1

    def sum_of(self, node):
        """Return the cached subtree sum of ``node`` (0 for ``None``)."""
        if node is None:
            return 0
        return node.sum

    def update(self, node):
        """Refresh every cached aggregate of ``node`` from its children."""
        if node is None:
            return
        left = node.left
        right = node.right
        if left is not None:
            left_height, left_size, left_sum = left.height, left.size, left.sum
        else:
            left_height = left_size = left_sum = 0
        if right is not None:
            right_height, right_size, right_sum = (
                right.height, right.size, right.sum)
        else:
            right_height = right_size = right_sum = 0
        node.height = max(left_height, right_height) + 1
        node.size = left_size + right_size + 1
        node.bias = left_height - right_height
        node.sum = left_sum + right_sum + node.value

    def rotate_left(self, node):
        """Rotate ``node`` down to the left; return the new subtree root."""
        pivot = node.right
        node.right = pivot.left
        pivot.left = node
        # The demoted node must be refreshed first: it is now a child of pivot.
        self.update(node)
        self.update(pivot)
        return pivot

    def rotate_right(self, node):
        """Rotate ``node`` down to the right; return the new subtree root."""
        pivot = node.left
        node.left = pivot.right
        pivot.right = node
        self.update(node)
        self.update(pivot)
        return pivot

    def balance(self, node):
        """Restore the balance of ``node`` and return the subtree root.

        ``node.bias`` must be up to date (call ``update`` first).  A child
        leaning the opposite way is rotated first (double rotation).
        """
        if node is None:
            return node
        if node.bias >= 2:
            if node.left is not None and node.left.bias < 0:
                node.left = self.rotate_left(node.left)
            return self.rotate_right(node)
        if node.bias <= -2:
            if node.right is not None and node.right.bias > 0:
                node.right = self.rotate_right(node.right)
            return self.rotate_left(node)
        return node

    def add_rec(self, cur, value):
        """Insert ``value`` below ``cur``; return the rebalanced subtree root."""
        if cur is None:
            return self.Node(value)
        if value < cur.value:
            cur.left = self.add_rec(cur.left, value)
        else:
            # Ties go to the right, which is what allows duplicates.
            cur.right = self.add_rec(cur.right, value)
        self.update(cur)
        return self.balance(cur)

    def get_max_node(self, node):
        """Return the right-most node of the subtree rooted at ``node``."""
        while node.right is not None:
            node = node.right
        return node

    def remove_rec(self, cur, value):
        """Remove one occurrence of ``value`` below ``cur``.

        Returns the rebalanced subtree root.  A value that is not present
        leaves the set of stored values unchanged.
        """
        if cur is None:
            return None
        if value == cur.value:
            if cur.left is not None and cur.right is not None:
                # Two children: take over the predecessor's value and delete
                # that predecessor from the left subtree instead.
                predecessor = self.get_max_node(cur.left)
                cur.value = predecessor.value
                cur.left = self.remove_rec(cur.left, predecessor.value)
            else:
                # At most one child: splice it into this node's place.
                return cur.left if cur.left is not None else cur.right
        elif value < cur.value:
            cur.left = self.remove_rec(cur.left, value)
        else:
            cur.right = self.remove_rec(cur.right, value)
        self.update(cur)
        return self.balance(cur)

    def lower_bound_rec(self, cur, value, acc):
        """Return ``acc`` plus the count of values ``< value`` below ``cur``.

        Implemented as a loop rather than recursion; the name and signature
        are kept so existing callers continue to work.
        """
        while cur is not None:
            if value <= cur.value:
                cur = cur.left
            else:
                # cur and its whole left subtree are smaller than value.
                left = cur.left
                acc += (left.size if left is not None else 0) + 1
                cur = cur.right
        return acc

    def prefix_sum_rec(self, cur, r):
        """Return the sum of the ``r`` smallest values below ``cur``.

        Implemented as a loop rather than recursion; the name and signature
        are kept so existing callers continue to work.
        """
        total = 0
        while cur is not None:
            left = cur.left
            left_size = left.size if left is not None else 0
            if r <= left_size:
                cur = left
                continue
            # The whole left subtree and cur itself belong to the prefix.
            total += (left.sum if left is not None else 0) + cur.value
            if r == left_size + 1:
                return total
            r -= left_size + 1
            cur = cur.right
        return total

    def size(self):
        """Return the number of stored values."""
        return self.size_of(self.root)

    def add(self, value):
        """Insert ``value`` (duplicates are kept)."""
        self.root = self.add_rec(self.root, value)

    def remove(self, value):
        """Remove one occurrence of ``value`` if it is present."""
        self.root = self.remove_rec(self.root, value)

    def lower_bound(self, value):
        """Return how many stored values are strictly less than ``value``."""
        return self.lower_bound_rec(self.root, value, 0)

    def prefix_sum(self, r):
        """Return the sum of the ``r`` smallest stored values.

        ``r <= 0`` (or an empty container) yields 0; an ``r`` larger than the
        number of stored values yields the sum of everything.
        """
        if self.root is None or r <= 0:
            return 0
        if r >= self.size() + 1:
            return self.root.sum
        return self.prefix_sum_rec(self.root, r)
def encode(s, start):
    """Pack the three letters ``s[start]``, ``s[start + 1]``, ``s[start + 2]``.

    The letters are read as digits of a base-26 number relative to ``'a'``,
    most significant first, so three lowercase ASCII letters map to an
    integer in ``0 .. 26**3 - 1``.

    Args:
        s: An indexable sequence of single characters (a ``str`` or a list
            of one-character strings).
        start: Index of the first of the three characters.

    Returns:
        The integer code of the triple.

    Raises:
        IndexError: If fewer than three characters are available from
            ``start``.
    """
    base = ord('a')
    code = 0
    for offset in range(3):
        # Horner's scheme: shift the digits read so far, append the next one.
        code = code * 26 + (ord(s[start + offset]) - base)
    return code
def main():
    """Read a string and a list of queries from stdin and answer them.

    Input layout, as consumed by this function:

    * first line: ``N Q`` -- string length and number of queries;
    * second line: the string itself;
    * then ``Q`` lines, each either
      ``1 k x``   -- replace the character at 1-based position ``k`` by ``x``, or
      ``2 l r a`` -- report on the three-letter word ``a`` inside ``[l, r]``.

    For a type-2 query the printed value is the sum of ``i - l + 1`` over all
    1-based start positions ``i`` with ``l <= i <= r - 2`` at which ``a``
    currently occurs.

    One ``Set`` per three-letter code stores the 1-based start positions of
    that triple; the sets are created lazily.
    """
    import sys

    readline = sys.stdin.readline

    header = readline().split()
    N = int(header[0])
    Q = int(header[1])
    state = list(readline().strip())

    # indices[code] holds the start positions of the triple with that code,
    # or None while the triple has never occurred.
    indices = [None] * (26 * 26 * 26)
    for i in range(N - 2):
        code = encode(state, i)
        if indices[code] is None:
            indices[code] = Set()
        indices[code].add(i + 1)

    # Answers are collected and written once; one write per query is slow.
    answers = []
    for _ in range(Q):
        data = readline().split()
        q = int(data[0])
        if q == 1:
            k = int(data[1]) - 1
            x = data[2].strip()
            # Only the triples starting at k - 2, k - 1 and k contain
            # position k; clamp that window to valid start positions.
            first = max(k - 2, 0)
            last = min(k, N - 3)
            for j in range(first, last + 1):
                bucket = indices[encode(state, j)]
                if bucket is not None:
                    bucket.remove(j + 1)
            state[k] = x
            for j in range(first, last + 1):
                c = encode(state, j)
                if indices[c] is None:
                    indices[c] = Set()
                indices[c].add(j + 1)
        elif q == 2:
            l = int(data[1])
            r = int(data[2])
            a = data[3].strip()
            bucket = indices[encode(a, 0)]
            if bucket is None:
                answers.append("0\n")
            else:
                # Ranks delimiting the start positions inside [l, r - 2].
                p = bucket.lower_bound(r - 1)
                s = bucket.lower_bound(l)
                ans = bucket.prefix_sum(p) - bucket.prefix_sum(s)
                # Turn the sum of positions into the sum of (i - l + 1).
                ans -= (p - s) * (l - 1)
                answers.append(f"{ans}\n")

    sys.stdout.write("".join(answers))
# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Nauclhlt🪷