結果
| 問題 | No.3239 Omnibus |
| コンテスト | |
| ユーザー | Nauclhlt🪷 |
| 提出日時 | 2025-08-14 22:51:23 |
| 言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
| 結果 | AC |
| 実行時間 | 9,463 ms / 10,000 ms |
| コード長 | 6,297 bytes |
| 記録 | |
| コンパイル時間 | 252 ms |
| コンパイル使用メモリ | 12,800 KB |
| 実行使用メモリ | 66,252 KB |
| 最終ジャッジ日時 | 2025-08-14 23:54:05 |
| 合計ジャッジ時間 | 125,122 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 33 |
ソースコード
import sys
# The AVL helpers below (_add/_remove) recurse once per tree level, and the
# tree is kept height-balanced, so real recursion depth is small; this limit
# is simply a generous safety margin.
sys.setrecursionlimit(2 ** 25)
class Set:
    """Ordered multiset of numbers stored in an AVL tree.

    Every node caches the height, node count and value sum of its subtree.
    Those aggregates let ``add``/``remove`` rebalance in O(log n), and let
    ``lower_bound`` (rank) and ``prefix_sum`` (sum of the r smallest values)
    answer without walking the whole tree. Equal values are allowed.
    """

    class Node:
        __slots__ = ['value', 'sum', 'left', 'right', 'height', 'size']

        def __init__(self, value):
            self.value = value
            self.sum = value   # sum of every value in this subtree
            self.left = None
            self.right = None
            self.height = 1    # a lone leaf has height 1 (empty subtree: 0)
            self.size = 1      # number of nodes in this subtree

    def __init__(self):
        self.root = None

    # ---------------- AVL primitives ----------------
    def _height(self, node):
        """Height of ``node``'s subtree, 0 for an empty subtree."""
        return node.height if node else 0

    def _size(self, node):
        """Node count of ``node``'s subtree, 0 for an empty subtree."""
        return node.size if node else 0

    def _sum(self, node):
        """Value sum of ``node``'s subtree, 0 for an empty subtree."""
        return node.sum if node else 0

    def _update(self, node):
        """Recompute ``node``'s cached height/size/sum from its children."""
        if not node:
            return
        # Read the children directly: this runs on every level of every
        # insert/delete, so it avoids six helper-method calls per node.
        left, right = node.left, node.right
        node.height = max(left.height if left else 0,
                          right.height if right else 0) + 1
        node.size = (left.size if left else 0) + (right.size if right else 0) + 1
        node.sum = (left.sum if left else 0) + (right.sum if right else 0) + node.value

    def _rotate_left(self, node):
        """Rotate left around ``node``; return the new subtree root."""
        r = node.right
        node.right = r.left
        r.left = node
        # ``node`` is now a child of ``r``, so it must be refreshed first.
        self._update(node)
        self._update(r)
        return r

    def _rotate_right(self, node):
        """Rotate right around ``node``; return the new subtree root."""
        l = node.left
        node.left = l.right
        l.right = node
        self._update(node)
        self._update(l)
        return l

    def _balance(self, node):
        """Refresh ``node`` and restore the AVL invariant; return the new root."""
        if not node:
            return node
        self._update(node)
        balance = self._height(node.left) - self._height(node.right)
        if balance >= 2:
            # Left-right case: straighten the zig-zag before the main rotation.
            if self._height(node.left.left) < self._height(node.left.right):
                node.left = self._rotate_left(node.left)
            return self._rotate_right(node)
        if balance <= -2:
            # Right-left case, mirror of the above.
            if self._height(node.right.right) < self._height(node.right.left):
                node.right = self._rotate_right(node.right)
            return self._rotate_left(node)
        return node

    # ---------------- insertion ----------------
    def _add(self, node, value):
        """Insert ``value`` under ``node``; return the rebalanced subtree root."""
        if not node:
            return self.Node(value)
        if value < node.value:
            node.left = self._add(node.left, value)
        else:
            # Equal values go right, so duplicates are kept.
            node.right = self._add(node.right, value)
        return self._balance(node)

    def add(self, value):
        """Insert ``value`` into the set (duplicates allowed)."""
        self.root = self._add(self.root, value)

    # ---------------- deletion ----------------
    def _get_max(self, node):
        """Return the right-most (largest) node of a non-empty subtree."""
        while node.right:
            node = node.right
        return node

    def _remove(self, node, value):
        """Remove one ``value`` under ``node``; return the rebalanced root."""
        if not node:
            return None  # value not present: nothing to do
        if value < node.value:
            node.left = self._remove(node.left, value)
        elif value > node.value:
            node.right = self._remove(node.right, value)
        else:
            if not node.left:
                return node.right
            if not node.right:
                return node.left
            # Two children: copy the in-order predecessor into this node,
            # then delete that predecessor from the left subtree.
            max_left = self._get_max(node.left)
            node.value = max_left.value
            node.left = self._remove(node.left, max_left.value)
        return self._balance(node)

    def remove(self, value):
        """Remove one occurrence of ``value``; silently a no-op if absent."""
        self.root = self._remove(self.root, value)

    # ---------------- rank query ----------------
    def lower_bound(self, value):
        """Return how many stored elements are strictly less than ``value``."""
        node = self.root
        res = 0
        while node:
            if value <= node.value:
                node = node.left
            else:
                # node and its whole left subtree are < value
                res += self._size(node.left) + 1
                node = node.right
        return res

    # ---------------- prefix sum by rank ----------------
    def prefix_sum(self, r):
        """Return the sum of the ``r`` smallest elements.

        Returns 0 when ``r <= 0`` and the total sum when ``r`` exceeds the size.
        """
        node = self.root
        res = 0
        while node and r > 0:
            left_sz = self._size(node.left)
            if r <= left_sz:
                node = node.left
            elif r == left_sz + 1:
                res += self._sum(node.left) + node.value
                break
            else:
                res += self._sum(node.left) + node.value
                r -= left_sz + 1
                node = node.right
        return res
# Place values for packing three lowercase letters into one base-26 integer.
POW26 = [1, 26, 26 * 26]


def encode(s, start):
    """Pack the 3 lowercase letters at ``s[start:start + 3]`` into 0..26**3-1.

    ``s`` may be a ``str`` (characters) or a sequence of ASCII byte values
    (``bytes``, ``bytearray``, ``list`` or ``tuple`` of ints). 'a' maps to 0
    and 'z' to 25; the first letter is the most significant digit.

    Raises:
        TypeError: if ``s`` is of any other type.
    """
    if isinstance(s, str):
        a, b, c = ord(s[start]), ord(s[start + 1]), ord(s[start + 2])
    elif isinstance(s, (list, tuple, bytes, bytearray)):
        # Indexing these yields ints (byte values) directly.
        a, b, c = s[start], s[start + 1], s[start + 2]
    else:
        raise TypeError("Unsupported type for encode")
    return (a - 97) * POW26[2] + (b - 97) * POW26[1] + (c - 97)
def main():
    """Solve yukicoder No.3239 "Omnibus", reading stdin and writing stdout.

    Input as parsed here: ``N Q`` on the first line, the string ``S`` on the
    second, then ``Q`` query lines:

    * ``1 k x`` -- overwrite the k-th (1-indexed) character of S with ``x``.
    * ``2 l r a`` -- over every 1-indexed start position p in [l, r - 2]
      whose 3-character window equals ``a``, output the sum of (p - l + 1).

    Lines whose query type is neither 1 nor 2 are ignored.
    """
    readline = sys.stdin.buffer.readline
    header = readline().split()
    n = int(header[0])
    q = int(header[1])
    state = bytearray(readline().strip())

    # indices[code] -> Set of 1-indexed start positions of every window whose
    # encode() value is ``code``; a Set is created only once one appears.
    indices = [None] * (26 * 26 * 26)
    for i in range(n - 2):
        code = encode(state, i)
        if indices[code] is None:
            indices[code] = Set()
        indices[code].add(i + 1)

    out_lines = []
    for _ in range(q):
        data = readline().split()
        kind = int(data[0])
        if kind == 1:
            k = int(data[1]) - 1
            new_char = data[2][0]  # first byte of the token, as an int
            if state[k] == new_char:
                continue  # nothing changes; skip the six tree updates
            # Windows starting at j in [k-2, k] cover position k; the valid
            # window starts are 0..n-3 (an empty range when n < 3).
            start = max(0, k - 2)
            end = min(n - 3, k)
            for j in range(start, end + 1):
                tree = indices[encode(state, j)]
                if tree is not None:
                    tree.remove(j + 1)
            state[k] = new_char
            for j in range(start, end + 1):
                code = encode(state, j)
                if indices[code] is None:
                    indices[code] = Set()
                indices[code].add(j + 1)
        elif kind == 2:
            left = int(data[1])
            right = int(data[2])
            tree = indices[encode(data[3].decode(), 0)]
            if tree is None or right - 2 < left:
                # No window of this kind exists, or the start range is empty.
                out_lines.append(b"0\n")
                continue
            hi = tree.lower_bound(right - 1)  # starts <= right - 2
            lo = tree.lower_bound(left)       # starts < left
            total = tree.prefix_sum(hi) - tree.prefix_sum(lo)
            total -= (hi - lo) * (left - 1)   # turn each start p into p - left + 1
            out_lines.append(f"{total}\n".encode())
    sys.stdout.buffer.write(b"".join(out_lines))
# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Nauclhlt🪷