結果
| 問題 | No.3239 Omnibus |
| コンテスト | |
| ユーザー | Nauclhlt🪷 |
| 提出日時 | 2025-08-14 22:47:19 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | WA |
| 実行時間 | - |
| コード長 | 7,172 bytes |
| 記録 | |
| コンパイル時間 | 391 ms |
| コンパイル使用メモリ | 82,276 KB |
| 実行使用メモリ | 150,272 KB |
| 最終ジャッジ日時 | 2025-08-14 22:48:16 |
| 合計ジャッジ時間 | 49,631 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 4 WA * 26 TLE * 1 -- * 1 |
ソースコード
import sys

# Raise the interpreter recursion limit to 2**25 (== 1 << 25) frames.
# NOTE(review): nothing visible in this file appears to recurse, so this
# looks purely defensive.
sys.setrecursionlimit(2 ** 25)
class Set:
    """Ordered multiset of numbers backed by an AVL tree.

    Every node caches the height, size and value-sum of its subtree. This
    lets ``lower_bound`` (rank of a value) and ``prefix_sum`` (sum of the r
    smallest elements) walk a single root-to-leaf path. Equal values are kept
    as separate nodes; a new duplicate is inserted to the right of existing
    equals.
    """

    __slots__ = ['root']

    class Node:
        __slots__ = ['value', 'sum', 'left', 'right', 'bias', 'height', 'size']

        def __init__(self, value):
            self.value = value
            self.sum = value  # sum of all values in this subtree
            self.left = None
            self.right = None
            self.bias = 0  # height(left) - height(right)
            self.height = 1
            self.size = 1  # number of nodes in this subtree

    def __init__(self):
        self.root = None

    def update(self, node):
        """Recompute node's cached height, size, bias and sum from its children."""
        if not node:
            return
        lh = node.left.height if node.left else 0
        rh = node.right.height if node.right else 0
        ls = node.left.size if node.left else 0
        rs = node.right.size if node.right else 0
        ls_sum = node.left.sum if node.left else 0
        rs_sum = node.right.sum if node.right else 0
        node.height = max(lh, rh) + 1
        node.size = ls + rs + 1
        node.bias = lh - rh
        node.sum = ls_sum + rs_sum + node.value

    def rotate_left(self, node):
        """Rotate left around ``node``; return the new subtree root."""
        r = node.right
        node.right = r.left
        r.left = node
        self.update(node)
        self.update(r)
        return r

    def rotate_right(self, node):
        """Rotate right around ``node``; return the new subtree root."""
        l = node.left
        node.left = l.right
        l.right = node
        self.update(node)
        self.update(l)
        return l

    def balance(self, node):
        """Fix a |bias| >= 2 at ``node`` with a single or double rotation.

        Expects ``node``'s cached fields to be current. Returns the
        (possibly new) root of this subtree.
        """
        if not node:
            return node
        if node.bias >= 2:
            if node.left and node.left.bias < 0:
                node.left = self.rotate_left(node.left)
            return self.rotate_right(node)
        if node.bias <= -2:
            if node.right and node.right.bias > 0:
                node.right = self.rotate_right(node.right)
            return self.rotate_left(node)
        return node

    def _retrace(self, path):
        """Refresh and rebalance every node on ``path``, bottom-up.

        ``path`` lists nodes from the root down to the deepest node whose
        subtree changed. A rotation can replace a node with a different
        subtree root, so each result is re-linked into its parent, or
        becomes ``self.root`` when it is the top of the path.
        """
        for i in range(len(path) - 1, -1, -1):
            node = path[i]
            self.update(node)
            new_node = self.balance(node)
            if new_node is node:
                continue
            if i == 0:
                self.root = new_node
            else:
                parent = path[i - 1]
                if parent.left is node:
                    parent.left = new_node
                else:
                    parent.right = new_node

    def add(self, value):
        """Insert ``value`` (duplicates allowed) and restore AVL balance."""
        new_node = self.Node(value)
        if self.root is None:
            self.root = new_node
            return
        path = []
        cur = self.root
        while cur is not None:
            path.append(cur)
            cur = cur.left if value < cur.value else cur.right
        parent = path[-1]
        if value < parent.value:
            parent.left = new_node
        else:
            parent.right = new_node
        # Every ancestor of the new leaf may need updating and rebalancing,
        # not just the root.
        self._retrace(path)

    def get_max_node(self, node):
        """Return the right-most (largest) node in the subtree at ``node``."""
        while node.right:
            node = node.right
        return node

    def remove(self, value):
        """Delete one node holding ``value``; do nothing if it is absent."""
        path = []  # ancestors of the node that will be physically unlinked
        cur = self.root
        while cur is not None and cur.value != value:
            path.append(cur)
            cur = cur.left if value < cur.value else cur.right
        if cur is None:
            return
        if cur.left is not None and cur.right is not None:
            # Two children: copy the in-order predecessor (max of the left
            # subtree) into this node, then unlink the predecessor instead.
            target = cur
            path.append(target)
            cur = target.left
            while cur.right is not None:
                path.append(cur)
                cur = cur.right
            target.value = cur.value
        # ``cur`` now has at most one child; splice it out of the tree.
        child = cur.left if cur.left is not None else cur.right
        if not path:
            self.root = child
        else:
            parent = path[-1]
            if parent.left is cur:
                parent.left = child
            else:
                parent.right = child
        # Rebalance all ancestors while keeping the real root attached.
        self._retrace(path)

    def lower_bound(self, value):
        """Return how many stored values are strictly less than ``value``."""
        if not self.root:
            return 0
        res = 0
        cur = self.root
        while cur:
            if value <= cur.value:
                cur = cur.left
            else:
                res += (cur.left.size if cur.left else 0) + 1
                cur = cur.right
        return res

    def prefix_sum(self, r):
        """Return the sum of the ``r`` smallest stored values.

        ``r <= 0`` gives 0; ``r`` at or beyond the size gives the total sum.
        """
        if not self.root or r <= 0:
            return 0
        if r >= self.root.size:
            return self.root.sum
        res = 0
        cur = self.root
        while cur:
            left_sz = cur.left.size if cur.left else 0
            if r <= left_sz:
                cur = cur.left
            elif r == left_sz + 1:
                res += (cur.left.sum if cur.left else 0) + cur.value
                break
            else:
                res += (cur.left.sum if cur.left else 0) + cur.value
                r -= left_sz + 1
                cur = cur.right
        return res
# Positional weights of a 3-letter base-26 code: 26**0, 26**1, 26**2.
POW26 = [1, 26, 26 * 26]


def encode(s, start):
    """Pack s[start], s[start + 1], s[start + 2] into one base-26 integer.

    Each element is offset by 97 (so lowercase 'a'..'z' map to 0..25), and
    the first letter is the most significant digit. ``s`` may be a ``str``
    (characters are converted with ``ord``) or a ``list``/``bytearray`` that
    already holds byte values. Any other type raises ``TypeError``.
    """
    if isinstance(s, str):
        letters = (ord(s[start + k]) for k in range(3))
    elif isinstance(s, (list, bytearray)):
        letters = (s[start + k] for k in range(3))
    else:
        raise TypeError("Unsupported type for encode")
    weights = (POW26[2], POW26[1], POW26[0])
    return sum((code - 97) * w for code, w in zip(letters, weights))
def main():
    """Read the problem from stdin, answer the queries, write to stdout.

    Input layout, as parsed below: a line "N Q", a line holding the string S,
    then Q query lines:

    * ``1 k x`` -- overwrite the k-th (1-based) character of S with x.
    * ``2 l r a`` -- output the sum of (i - l + 1) over every 1-based start
      position i with l <= i <= r - 2 whose 3-letter window equals a.

    For each 3-letter code, the 1-based start positions where it occurs are
    kept in a ``Set``. An update touches at most three windows, and a query
    reads a single ``Set``. All answers are buffered and written once at the
    end.
    """
    readline = sys.stdin.buffer.readline
    write = sys.stdout.buffer.write

    header = readline().split()
    n = int(header[0])
    query_count = int(header[1])
    state = bytearray(readline().decode().strip(), 'ascii')
    buckets = [None] * (26 * 26 * 26)

    def insert_window(pos):
        # Register the window starting at 0-based ``pos`` under its code,
        # creating that code's Set on first use.
        code = encode(state, pos)
        bucket = buckets[code]
        if bucket is None:
            bucket = Set()
            buckets[code] = bucket
        bucket.add(pos + 1)

    for pos in range(n - 2):
        insert_window(pos)

    answers = []
    for _ in range(query_count):
        tokens = readline().split()
        kind = int(tokens[0])
        if kind == 1:
            k = int(tokens[1]) - 1
            letter = tokens[2].decode()
            # Only windows starting in [k - 2, k] contain position k, and
            # valid 0-based window starts run from 0 to n - 3.
            touched = range(max(0, k - 2), min(n - 3, k) + 1)
            for pos in touched:
                bucket = buckets[encode(state, pos)]
                if bucket is not None:
                    bucket.remove(pos + 1)
            state[k] = ord(letter[0])
            for pos in touched:
                insert_window(pos)
        elif kind == 2:
            left = int(tokens[1])
            right = int(tokens[2])
            bucket = buckets[encode(tokens[3].decode(), 0)]
            if bucket is None:
                answers.append(b"0\n")
                continue
            # hi = #occurrences <= right - 2, lo = #occurrences < left; the
            # ranks in [lo, hi) are exactly the starts inside [left, right-2].
            hi = bucket.lower_bound(right - 1)
            lo = bucket.lower_bound(left)
            total = bucket.prefix_sum(hi) - bucket.prefix_sum(lo)
            total -= (hi - lo) * (left - 1)
            answers.append(f"{total}\n".encode())

    write(b"".join(answers))


if __name__ == "__main__":
    main()
Nauclhlt🪷