結果

| 問題 | No.2361 Many String Compare Queries |
|---|---|
| コンテスト | |
| ユーザー | 遭難者 |
| 提出日時 | 2023-01-13 15:27:48 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | RE (最新) / AC (最初) |
| 実行時間 | - |
| コード長 | 7,593 bytes |
| コンパイル時間 | 250 ms |
| コンパイル使用メモリ | 81,772 KB |
| 実行使用メモリ | 115,512 KB |
| 最終ジャッジ日時 | 2024-06-30 20:39:55 |
| 合計ジャッジ時間 | 11,168 ms |
| ジャッジサーバーID (参考情報) | judge1 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 11 RE * 3 |
ソースコード
import sys
class SegTree:
    """Point-update / range-min segment tree with recursive searches.

    The number of leaves is rounded up to a power of two ``N2``. Leaves past
    ``len(d)`` are filled with ``INF`` and take part in searches like real
    leaves. ``INF`` is also the "not found" result of :meth:`minLeft`, so
    choose it with that double role in mind.
    """

    def __init__(self, d, INF):
        """Build the tree over the sequence *d*, padding with *INF*."""
        size = len(d)
        n2 = 1
        while n2 < size:
            n2 <<= 1
        self.N2 = n2  # number of leaves (a power of two)
        self.N = n2 << 1  # length of the 1-indexed heap array
        self.INF = INF
        self.data = [INF] * self.N
        self.data[n2:n2 + size] = d
        for k in range(n2 - 1, 0, -1):
            self.data[k] = min(self.data[2 * k], self.data[2 * k + 1])

    # Named prod0 (not prod): when both were called ``prod`` the public
    # two-argument method shadowed this helper and prod() raised TypeError.
    def prod0(self, a, b, k, l, r):
        """Return the minimum over [a, b) inside node *k*, which covers [l, r)."""
        if r <= a or b <= l:
            return self.INF
        if a <= l and r <= b:
            return self.data[k]
        mid = (l + r) >> 1
        return min(self.prod0(a, b, 2 * k, l, mid),
                   self.prod0(a, b, 2 * k + 1, mid, r))

    def prod(self, l, r):
        """Return ``min(d[l:r])``, or ``INF`` when the range is empty."""
        return self.prod0(l, r, 1, 0, self.N2)

    def set(self, i, x):
        """Assign ``d[i] = x`` and refresh the minima on the path to the root."""
        k = i + self.N2
        self.data[k] = x
        while k != 1:
            k >>= 1
            self.data[k] = min(self.data[2 * k], self.data[2 * k + 1])

    def maxRight0(self, idx, jud, k, l, r):
        """Search node *k* (covering [l, r)) for the rightmost leaf index
        <= *idx* whose value is < *jud*; return it, or -1 if there is none."""
        if self.data[k] >= jud or idx < l:
            return -1
        if k >= self.N2:
            return k - self.N2
        mid = (l + r) >> 1
        # Try the right half first so the first hit is the rightmost one.
        found = self.maxRight0(idx, jud, 2 * k + 1, mid, r)
        if found != -1:
            return found
        return self.maxRight0(idx, jud, 2 * k, l, mid)

    def maxRight(self, idx, jud):
        """Return the largest i <= *idx* with ``d[i] < jud``, or -1."""
        return self.maxRight0(idx, jud, 1, 0, self.N2)

    def minLeft0(self, idx, jud, k, l, r):
        """Search node *k* (covering [l, r)) for the leftmost leaf index
        >= *idx* whose value is < *jud*; return it, or ``INF`` if there is none."""
        if self.data[k] >= jud or r <= idx:
            return self.INF
        if k >= self.N2:
            return k - self.N2
        mid = (l + r) >> 1
        # Try the left half first so the first hit is the leftmost one.
        found = self.minLeft0(idx, jud, 2 * k, l, mid)
        if found != self.INF:
            return found
        return self.minLeft0(idx, jud, 2 * k + 1, mid, r)

    def minLeft(self, idx, jud):
        """Return the smallest i >= *idx* with ``d[i] < jud`` (padding leaves
        included), or ``INF`` when no leaf qualifies."""
        return self.minLeft0(idx, jud, 1, 0, self.N2)
class string:
    """String algorithms ported from the AtCoder Library (ACL).

    Every member is a static method; call them as ``string.suffix_array(s)``.
    """

    @staticmethod
    def sa_is(s, upper):
        """Return the suffix array of *s* computed with SA-IS.

        *s* is a list of ints in ``[0, upper]``. The result lists suffix
        start positions in sorted order; a suffix that is a proper prefix of
        another sorts first.
        """
        n = len(s)
        if n == 0:
            return []
        if n == 1:
            return [0]
        if n == 2:
            if s[0] < s[1]:
                return [0, 1]
            else:
                return [1, 0]
        sa = [0] * n
        # ls[i] is truthy when suffix i is S-type (smaller than suffix i+1).
        # The last suffix stays L-type (0).
        ls = [0] * n
        for i in range(n - 2, -1, -1):
            ls[i] = ls[i + 1] if (s[i] == s[i + 1]) else (s[i] < s[i + 1])
        # After the two loops, sum_l[c] is the start of bucket c and sum_s[c]
        # is the start of the S-type part of bucket c.
        sum_l = [0] * (upper + 1)
        sum_s = [0] * (upper + 1)
        for i in range(n):
            if not ls[i]:
                sum_s[s[i]] += 1
            else:
                sum_l[s[i] + 1] += 1
        for i in range(upper + 1):
            sum_s[i] += sum_l[i]
            if i < upper:
                sum_l[i + 1] += sum_s[i]

        def induce(lms):
            # Induced sort: seed the LMS positions, then sweep forward for
            # L-type suffixes and backward for S-type suffixes.
            for i in range(n):
                sa[i] = -1
            buf = sum_s[:]
            for d in lms:
                if d == n:
                    continue
                sa[buf[s[d]]] = d
                buf[s[d]] += 1
            buf = sum_l[:]
            sa[buf[s[n - 1]]] = n - 1
            buf[s[n - 1]] += 1
            for i in range(n):
                v = sa[i]
                if v >= 1 and not ls[v - 1]:
                    sa[buf[s[v - 1]]] = v - 1
                    buf[s[v - 1]] += 1
            buf = sum_l[:]
            for i in range(n - 1, -1, -1):
                v = sa[i]
                if v >= 1 and ls[v - 1]:
                    buf[s[v - 1] + 1] -= 1
                    sa[buf[s[v - 1] + 1]] = v - 1

        # LMS positions: an S-type position directly after an L-type one.
        lms_map = [-1] * (n + 1)
        m = 0
        for i in range(1, n):
            if not ls[i - 1] and ls[i]:
                lms_map[i] = m
                m += 1
        lms = []
        for i in range(1, n):
            if not ls[i - 1] and ls[i]:
                lms.append(i)
        induce(lms)
        if m:
            sorted_lms = []
            for v in sa:
                if lms_map[v] != -1:
                    sorted_lms.append(v)
            # Name each LMS substring; equal substrings share a name.
            rec_s = [0] * m
            rec_upper = 0
            rec_s[lms_map[sorted_lms[0]]] = 0
            for i in range(1, m):
                l = sorted_lms[i - 1]
                r = sorted_lms[i]
                end_l = lms[lms_map[l] + 1] if (lms_map[l] + 1 < m) else n
                end_r = lms[lms_map[r] + 1] if (lms_map[r] + 1 < m) else n
                same = True
                if end_l - l != end_r - r:
                    same = False
                else:
                    while l < end_l:
                        if s[l] != s[r]:
                            break
                        l += 1
                        r += 1
                    if (l == n) or (s[l] != s[r]):
                        same = False
                if not same:
                    rec_upper += 1
                rec_s[lms_map[sorted_lms[i]]] = rec_upper
            # Recurse on the reduced string to order the LMS suffixes exactly.
            rec_sa = string.sa_is(rec_s, rec_upper)
            for i in range(m):
                sorted_lms[i] = lms[rec_sa[i]]
            induce(sorted_lms)
        return sa

    @staticmethod
    def suffix_array_upper(s, upper):
        """Return the suffix array of an int list whose values lie in [0, upper]."""
        assert 0 <= upper
        for d in s:
            assert 0 <= d and d <= upper
        return string.sa_is(s, upper)

    @staticmethod
    def suffix_array(s):
        """Return the suffix array of a ``str`` or of a list of comparable values.

        A ``str`` is mapped through ``ord`` (alphabet 0..255). Any other
        sequence is first rank-compressed so that equal values get equal ranks.
        """
        n = len(s)
        if isinstance(s, str):
            s2 = [ord(c) for c in s]
            return string.sa_is(s2, 255)
        idx = list(range(n))
        idx.sort(key=lambda x: s[x])
        s2 = [0] * n
        now = 0
        for i in range(n):
            # Bump the rank only when the value changes. The original wrote
            # `i & a != b`, which parses as `(i & a) != b` and split runs of
            # equal values into distinct ranks.
            if i and s[idx[i - 1]] != s[idx[i]]:
                now += 1
            s2[idx[i]] = now
        return string.sa_is(s2, now)

    @staticmethod
    def lcp_array(s, sa):
        """Return ``lcp`` with ``lcp[i]`` = longest common prefix of the suffixes
        ``sa[i]`` and ``sa[i + 1]`` (Kasai's algorithm, length ``len(s) - 1``)."""
        n = len(s)
        assert n >= 1
        rnk = [0] * n
        for i in range(n):
            rnk[sa[i]] = i
        lcp = [0] * (n - 1)
        h = 0
        for i in range(n):
            if h > 0:
                h -= 1
            if rnk[i] == 0:
                continue
            j = sa[rnk[i] - 1]
            while j + h < n and i + h < n:
                if s[j + h] != s[i + h]:
                    break
                h += 1
            lcp[rnk[i] - 1] = h
        return lcp

    @staticmethod
    def z_algorithm(s):
        """Return ``z`` where ``z[i]`` is the length of the longest common
        prefix of ``s`` and ``s[i:]`` (``z[0] == len(s)``)."""
        n = len(s)
        if n == 0:
            return []
        z = [0] * n
        i = 1
        j = 0
        while i < n:
            z[i] = 0 if (j + z[j] <= i) else min(j + z[j] - i, z[i - j])
            while (i + z[i] < n) and (s[z[i]] == s[i + z[i]]):
                z[i] += 1
            if j + z[j] < i + z[i]:
                j = i
            i += 1
        z[0] = n
        return z
def main():
    """Read the input from stdin and print one answer per query.

    Input layout, as read below:
      * ``n q``
      * the string S
      * q lines ``L R`` with 1-based inclusive bounds.

    For each query, the printed value counts the (start, end) pairs of S
    whose substring is lexicographically strictly smaller than
    T = S[L-1:R]. Per the page header, this is yukicoder No.2361.
    """
    readline = sys.stdin.readline
    _n, q = map(int, readline().split())
    # strip() removes the newline and any stray '\r' or trailing blanks. The
    # previous version kept the newline as a sentinel and popped its suffix.
    # Any extra trailing character then left a position >= n in the suffix
    # array, and a_inv[...] would raise IndexError.
    s = readline().strip()
    n = len(s)
    n1 = n - 1
    # a[k]: start of the k-th smallest suffix.
    # b[k]: LCP of the suffixes a[k] and a[k + 1] (length n - 1).
    a = string.suffix_array(s)
    b = string.lcp_array(s, a)
    a_inv = [0] * n
    for rank, pos in enumerate(a):
        a_inv[pos] = rank
    # cum_len[k] = number of (start, end) substrings whose start position
    # has rank <= k, i.e. prefix sums of the suffix lengths n - a[k].
    cum_len = [0] * n
    total = 0
    for rank, pos in enumerate(a):
        total += n - pos
        cum_len[rank] = total
    seg = SegTree(b, n1)
    # p[k] = sum over ranks j > k of LCP(suffix a[k], suffix a[j]).
    # nxt is the first index after k with b[nxt] <= b[k]. Before nxt the
    # running minimum of b[k..] is b[k]; from nxt onward it equals that of b[nxt..].
    p = [0] * n1
    for k in range(n1 - 1, -1, -1):
        nxt = seg.minLeft(k + 1, b[k] + 1)
        p[k] = b[k] * (nxt - k)
        if nxt != n1:
            p[k] += p[nxt]
    out = []
    for _ in range(q):
        ll, rr = map(int, readline().split())
        ll -= 1
        rank, x = a_inv[ll], rr - ll
        # shared: number of suffixes that start with T. Each contributes
        # exactly the x - 1 proper prefixes of T. The query suffix itself
        # is the initial 1.
        # rest: prefixes of suffixes that diverge from T within x characters.
        shared, rest = 1, 0
        if rank != 0:
            # z = last rank < `rank` whose LCP with its successor is < x.
            # Ranks 0..z sort below T and differ from it within x chars,
            # so every one of their prefixes is smaller.
            z = seg.maxRight(rank - 1, x)
            if z != -1:
                rest += cum_len[z]
            # Ranks z+1..rank-1 start with T. When z == -1 that is all `rank`
            # of them; the old code skipped this case.
            shared += rank - 1 - z
        if rank != n1:
            # y = first index >= rank whose LCP drops below x, or n1 if none.
            y = seg.minLeft(rank, x)
            if y != n1:
                rest += p[y]
            # Ranks rank+1..y (up to n-1 when y == n1) start with T. The old
            # code skipped the y == n1 case.
            shared += y - rank
        out.append(shared * (x - 1) + rest)
    if out:
        sys.stdout.write("\n".join(map(str, out)) + "\n")


if __name__ == "__main__":
    main()
遭難者