結果
| 問題 | No.3239 Omnibus |
| コンテスト | |
| ユーザー |
Nauclhlt🪷
|
| 提出日時 | 2025-08-14 22:54:33 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
RE
|
| 実行時間 | - |
| コード長 | 7,806 bytes |
| 記録 | |
| コンパイル時間 | 329 ms |
| コンパイル使用メモリ | 82,048 KB |
| 実行使用メモリ | 130,048 KB |
| 最終ジャッジ日時 | 2025-08-14 22:54:57 |
| 合計ジャッジ時間 | 14,637 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | RE * 1 |
| other | WA * 2 RE * 30 |
ソースコード
import sys
sys.setrecursionlimit(1 << 25)
class Set:
    """Ordered collection of integers stored in an AVL tree.

    Every node caches the size, height and value-sum of its subtree, so
    ``lower_bound`` (rank) and ``prefix_sum`` (sum of the k smallest values)
    each walk a single root-to-leaf path.  ``add`` and ``remove`` are
    iterative: they record the search path and rebalance it bottom-up.
    """

    class Node:
        """Tree node holding one value plus cached aggregates of its subtree."""

        __slots__ = ['value', 'sum', 'size', 'height', 'left', 'right']

        def __init__(self, value):
            self.value = value
            self.sum = value    # sum of all values in this subtree
            self.size = 1       # number of nodes in this subtree
            self.height = 1     # height of this subtree (a leaf is 1)
            self.left = None
            self.right = None

    def __init__(self):
        self.root = None

    # ---------------- helpers ----------------
    def _update(self, node):
        """Recompute ``height``, ``size`` and ``sum`` of *node* from its children."""
        if node is None:
            return
        left, right = node.left, node.right
        lh = left.height if left is not None else 0
        rh = right.height if right is not None else 0
        ls = left.size if left is not None else 0
        rs = right.size if right is not None else 0
        ls_sum = left.sum if left is not None else 0
        rs_sum = right.sum if right is not None else 0
        node.height = max(lh, rh) + 1
        node.size = ls + rs + 1
        node.sum = ls_sum + rs_sum + node.value

    def _rotate_left(self, node):
        """Rotate *node* left and return the new subtree root (its right child)."""
        r = node.right
        node.right = r.left
        r.left = node
        # The demoted node must be refreshed first: it is now a child of r.
        self._update(node)
        self._update(r)
        return r

    def _rotate_right(self, node):
        """Rotate *node* right and return the new subtree root (its left child)."""
        l = node.left
        node.left = l.right
        l.right = node
        self._update(node)
        self._update(l)
        return l

    def _balance(self, node):
        """Refresh *node*'s aggregates, restore the AVL invariant, return the subtree root.

        Returns ``None`` unchanged so callers may pass an empty subtree.
        """
        if node is None:
            return None
        self._update(node)
        left, right = node.left, node.right
        diff = (left.height if left is not None else 0) - (right.height if right is not None else 0)
        if diff >= 2:
            # Left-heavy; a taller inner grandchild needs a double rotation.
            outer = left.left.height if left.left is not None else 0
            inner = left.right.height if left.right is not None else 0
            if outer < inner:
                node.left = self._rotate_left(left)
            return self._rotate_right(node)
        if diff <= -2:
            # Right-heavy; mirror image of the case above.
            outer = right.right.height if right.right is not None else 0
            inner = right.left.height if right.left is not None else 0
            if outer < inner:
                node.right = self._rotate_right(right)
            return self._rotate_left(node)
        return node

    def _rebalance_path(self, path):
        """Rebalance every node of *path* (root first) from the bottom up.

        A rotation changes the root of a subtree, so the result is written
        back into the parent's child slot (or ``self.root``).  The slot is
        found by identity, which stays valid because a parent is only
        rebalanced after all of its descendants on the path.
        """
        for i in range(len(path) - 1, -1, -1):
            old = path[i]
            new = self._balance(old)
            if i == 0:
                self.root = new
            else:
                parent = path[i - 1]
                if parent.left is old:
                    parent.left = new
                else:
                    parent.right = new

    # ---------------- add (non-recursive) ----------------
    def add(self, value):
        """Insert *value*; a value equal to an existing one is placed to its right."""
        if self.root is None:
            self.root = self.Node(value)
            return
        path = []
        node = self.root
        while True:
            path.append(node)
            if value < node.value:
                if node.left is None:
                    node.left = self.Node(value)
                    break
                node = node.left
            else:
                if node.right is None:
                    node.right = self.Node(value)
                    break
                node = node.right
        self._rebalance_path(path)

    # ---------------- remove (non-recursive) ----------------
    def remove(self, value):
        """Remove one node holding *value*; do nothing when it is absent."""
        path = []
        node = self.root
        while node is not None and node.value != value:
            path.append(node)
            node = node.left if value < node.value else node.right
        if node is None:
            return  # not present

        if node.left is not None and node.right is not None:
            # Two children: overwrite this node with its in-order predecessor
            # (the maximum of the left subtree) and delete that node instead.
            # The overwritten node is itself on the path to the predecessor,
            # so it must be recorded for relinking and rebalancing.
            path.append(node)
            pred = node.left
            while pred.right is not None:
                path.append(pred)
                pred = pred.right
            node.value = pred.value
            node = pred

        # `node` now has at most one child; splice it out.
        child = node.left if node.left is not None else node.right
        if not path:
            self.root = child
        else:
            parent = path[-1]
            if parent.left is node:
                parent.left = child
            else:
                parent.right = child

        # Safe on an empty path, so deleting the last element leaves root None.
        self._rebalance_path(path)

    # ---------------- lower_bound ----------------
    def lower_bound(self, value):
        """Return the number of stored values strictly less than *value*."""
        node = self.root
        res = 0
        while node is not None:
            if value <= node.value:
                node = node.left
            else:
                # This node and its whole left subtree are smaller than value.
                res += (node.left.size if node.left is not None else 0) + 1
                node = node.right
        return res

    # ---------------- prefix_sum ----------------
    def prefix_sum(self, r):
        """Return the sum of the *r* smallest stored values.

        ``r <= 0`` yields 0; an *r* larger than the number of stored values
        yields the sum of all of them.
        """
        node = self.root
        res = 0
        while node is not None and r > 0:
            left_size = node.left.size if node.left is not None else 0
            left_sum = node.left.sum if node.left is not None else 0
            if r <= left_size:
                node = node.left
            elif r == left_size + 1:
                res += left_sum + node.value
                break
            else:
                res += left_sum + node.value
                r -= left_size + 1
                node = node.right
        return res
# Place values of the three base-26 digits, indexed by exponent.
POW26 = [1, 26, 26 * 26]


def encode(s, start):
    """Return the base-26 code of the three characters at ``s[start:start + 3]``.

    Each character contributes ``code - 97`` (so ``'a'`` is digit 0), with the
    first character as the most significant digit.

    Args:
        s: a ``str``, or a ``bytes`` / ``bytearray`` / ``list`` whose items
            are integer character codes.
        start: index of the first of the three characters.

    Raises:
        TypeError: if *s* is of any other type.
        IndexError: if fewer than three items are available from *start*.
    """
    if isinstance(s, str):
        a, b, c = ord(s[start]), ord(s[start + 1]), ord(s[start + 2])
    elif isinstance(s, (bytes, bytearray, list)):
        # Indexing these already yields integer codes.
        a, b, c = s[start], s[start + 1], s[start + 2]
    else:
        raise TypeError("Unsupported type for encode")
    return (a - 97) * POW26[2] + (b - 97) * POW26[1] + (c - 97)
def main():
    """Read the string and the queries from stdin and answer them on stdout.

    Input: a line ``N Q``, a line with the string, then ``Q`` query lines.

    * ``1 k x`` replaces the k-th character (1-based) with ``x``.
    * ``2 l r a`` looks at the 1-based start positions ``i`` with
      ``l <= i <= r - 2`` at which the three-letter string ``a`` currently
      occurs, and prints the sum of ``i - l + 1`` over them.

    One ``Set`` per three-letter code stores the 1-based start positions of
    that code's occurrences.
    """
    read_line = sys.stdin.buffer.readline
    emit = sys.stdout.buffer.write

    first = read_line().split()
    n = int(first[0])
    query_count = int(first[1])
    text = read_line().decode().strip()
    state = bytearray(text, 'ascii')

    # indices[code] is created lazily the first time that code is seen.
    indices = [None] * (26 * 26 * 26)
    for pos in range(n - 2):
        code = encode(state, pos)
        if indices[code] is None:
            indices[code] = Set()
        indices[code].add(pos + 1)

    out_lines = []
    for _ in range(query_count):
        parts = read_line().split()
        kind = int(parts[0])
        if kind == 1:
            k = int(parts[1]) - 1
            letter = parts[2].decode()
            # Only the windows starting at k-2, k-1 and k contain position k;
            # clamp them to the valid start range [0, n-3].
            first_start = max(0, k - 2)
            last_start = min(n - 3, k)
            window = range(first_start, last_start + 1)
            # Unregister the affected windows under their old codes ...
            for j in window:
                tree = indices[encode(state, j)]
                if tree is not None:
                    tree.remove(j + 1)
            state[k] = ord(letter[0])
            # ... and register them again under their new codes.
            for j in window:
                code = encode(state, j)
                if indices[code] is None:
                    indices[code] = Set()
                indices[code].add(j + 1)
        elif kind == 2:
            l = int(parts[1])
            r = int(parts[2])
            pattern = parts[3].decode()
            tree = indices[encode(pattern, 0)]
            if tree is None:
                out_lines.append(b"0\n")
                continue
            hi = tree.lower_bound(r - 1)  # positions <= r - 2
            lo = tree.lower_bound(l)      # positions < l
            ans = tree.prefix_sum(hi) - tree.prefix_sum(lo)
            # Shift each counted position i to its offset i - (l - 1).
            ans -= (hi - lo) * (l - 1)
            out_lines.append(f"{ans}\n".encode())

    emit(b"".join(out_lines))


if __name__ == "__main__":
    main()
Nauclhlt🪷