結果
| 問題 |
No.2361 Many String Compare Queries
|
| コンテスト | |
| ユーザー |
遭難者
|
| 提出日時 | 2023-05-08 09:16:49 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
RE
|
| 実行時間 | - |
| コード長 | 6,545 bytes |
| コンパイル時間 | 149 ms |
| コンパイル使用メモリ | 82,396 KB |
| 実行使用メモリ | 98,776 KB |
| 最終ジャッジ日時 | 2024-06-30 20:41:08 |
| 合計ジャッジ時間 | 3,006 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | RE * 2 |
| other | RE * 14 |
ソースコード
import sys
class SparseTable:
    """Sparse table for range-minimum queries over a fixed sequence.

    Attributes:
        n:   number of elements in the source sequence.
        log: log[k] is floor(log2(k)) for k >= 1 (log[0] stays 0).
        m:   index of the highest level stored, i.e. log[n].
        d:   d[k][j] is the minimum of the 2**k elements starting at j.
    """

    def __init__(self, a):
        """Build every level of the table from the sequence *a*."""
        size = len(a)
        self.n = size
        floor_log = [0] * (size + 1)
        for k in range(2, size + 1):
            floor_log[k] = floor_log[k >> 1] + 1
        self.log = floor_log
        self.m = floor_log[size]
        self.d = [[] for _ in range(self.m + 1)]
        self.d[0] = a[:]
        for level in range(self.m):
            half = 1 << level
            lower = self.d[level]
            # Number of windows of length 2 * half that fit in the sequence.
            width = size + 1 - (half << 1)
            self.d[level + 1] = [
                min(lower[j], lower[j + half]) for j in range(width)
            ]

    def min(self, l, r):
        """Return the minimum over the half-open index range [l, r)."""
        level = self.log[r - l]
        row = self.d[level]
        # Two windows of length 2**level that together cover [l, r).
        return min(row[l], row[r - (1 << level)])

    def maxRight(self, idx, jud):
        """Walk left from *idx* while the elements are >= *jud*.

        Returns the position just left of the maximal run ending at *idx*
        whose elements are all >= *jud*: *idx* itself when that run is
        empty, and -1 when the run reaches the start of the sequence.
        """
        pos = idx
        level = self.m
        while level >= 0:
            step = 1 << level
            start = pos - step + 1
            if start >= 0 and self.d[level][start] >= jud:
                pos -= step
            level -= 1
        return pos

    def minLeft(self, idx, jud):
        """Walk right from *idx* while the elements are >= *jud*.

        Returns the first position at or after *idx* whose element is
        smaller than *jud*, or the sequence length when there is none.
        """
        pos = idx
        level = self.m
        while level >= 0:
            row = self.d[level]
            if pos < len(row) and row[pos] >= jud:
                pos += 1 << level
            level -= 1
        return pos
class string:
    """Namespace of string algorithms: suffix array, LCP array and Z-array.

    Every function is static and is meant to be called through the class,
    e.g. ``string.suffix_array(s)``.
    """

    @staticmethod
    def sa_is(s, upper):
        """Return the suffix array of the integer sequence *s*.

        Uses induced sorting (SA-IS) and recurses on the LMS substrings.

        Args:
            s: sequence of integers, each in the range ``0..upper``.
            upper: largest value that may occur in *s*.

        Returns:
            A list holding the start indices of the suffixes of *s* in
            ascending lexicographic order.
        """
        n = len(s)
        if n == 0:
            return []
        if n == 1:
            return [0]
        if n == 2:
            if s[0] < s[1]:
                return [0, 1]
            else:
                return [1, 0]
        sa = [0] * n
        # ls[i] is truthy when suffix i is smaller than suffix i + 1
        # (S-type) and falsy otherwise (L-type); the last suffix is L-type.
        ls = [0] * n
        for i in range(n - 2, -1, -1):
            ls[i] = ls[i + 1] if (s[i] == s[i + 1]) else (s[i] < s[i + 1])
        # After the prefix-sum loop below, sum_l[c] is the first slot of the
        # bucket of value c and sum_s[c] the first slot of its S-type part.
        sum_l = [0] * (upper + 1)
        sum_s = [0] * (upper + 1)
        for i in range(n):
            if not ls[i]:
                sum_s[s[i]] += 1
            else:
                sum_l[s[i] + 1] += 1
        for i in range(upper + 1):
            sum_s[i] += sum_l[i]
            if i < upper:
                sum_l[i + 1] += sum_s[i]

        def induce(lms):
            """Fill *sa* by induced sorting, seeded with the LMS order *lms*."""
            for i in range(n):
                sa[i] = -1
            # Seed: drop the LMS suffixes into the S-part of their buckets.
            buf = sum_s[:]
            for d in lms:
                if d == n:
                    continue
                sa[buf[s[d]]] = d
                buf[s[d]] += 1
            # Left-to-right pass places the L-type suffixes.
            buf = sum_l[:]
            sa[buf[s[n - 1]]] = n - 1
            buf[s[n - 1]] += 1
            for i in range(n):
                v = sa[i]
                if v >= 1 and not ls[v - 1]:
                    sa[buf[s[v - 1]]] = v - 1
                    buf[s[v - 1]] += 1
            # Right-to-left pass places the S-type suffixes, filling each
            # bucket from its end.
            buf = sum_l[:]
            for i in range(n - 1, -1, -1):
                v = sa[i]
                if v >= 1 and ls[v - 1]:
                    buf[s[v - 1] + 1] -= 1
                    sa[buf[s[v - 1] + 1]] = v - 1

        # lms_map[i] is the ordinal of LMS position i, or -1.  The extra
        # slot at index n keeps lms_map[-1] == -1, so the -1 placeholders
        # that induce() may leave in sa are skipped below.
        lms_map = [-1] * (n + 1)
        m = 0
        for i in range(1, n):
            if not ls[i - 1] and ls[i]:
                lms_map[i] = m
                m += 1
        lms = []
        for i in range(1, n):
            if not ls[i - 1] and ls[i]:
                lms.append(i)
        induce(lms)
        if m:
            sorted_lms = []
            for v in sa:
                if lms_map[v] != -1:
                    sorted_lms.append(v)
            # Name the LMS substrings: equal substrings share a name.
            rec_s = [0] * m
            rec_upper = 0
            rec_s[lms_map[sorted_lms[0]]] = 0
            for i in range(1, m):
                l = sorted_lms[i - 1]
                r = sorted_lms[i]
                end_l = lms[lms_map[l] + 1] if (lms_map[l] + 1 < m) else n
                end_r = lms[lms_map[r] + 1] if (lms_map[r] + 1 < m) else n
                same = True
                if end_l - l != end_r - r:
                    same = False
                else:
                    while l < end_l:
                        if s[l] != s[r]:
                            break
                        l += 1
                        r += 1
                    if (l == n) or (s[l] != s[r]):
                        same = False
                if not same:
                    rec_upper += 1
                rec_s[lms_map[sorted_lms[i]]] = rec_upper
            # Sort the reduced string recursively, then induce once more
            # from the exact LMS order.
            rec_sa = string.sa_is(rec_s, rec_upper)
            for i in range(m):
                sorted_lms[i] = lms[rec_sa[i]]
            induce(sorted_lms)
        return sa

    @staticmethod
    def suffix_array_upper(s, upper):
        """Return the suffix array of *s*, whose values lie in ``0..upper``.

        Raises:
            AssertionError: if *upper* is negative or an element of *s*
                falls outside ``0..upper``.
        """
        assert 0 <= upper
        for d in s:
            assert 0 <= d and d <= upper
        return string.sa_is(s, upper)

    @staticmethod
    def suffix_array(s):
        """Return the suffix array of a string or of a comparable sequence.

        A ``str`` whose code points all fit in ``0..255`` is sorted directly
        on its code points.  Any other input is first coordinate-compressed
        to ranks ``0..k``, so arbitrary mutually comparable elements work.
        """
        if isinstance(s, str):
            codes = [ord(ch) for ch in s]
            if max(codes, default=0) <= 255:
                return string.sa_is(codes, 255)
            # Wider code points would overflow the 256 buckets; compress.
            s = codes
        n = len(s)
        idx = sorted(range(n), key=lambda x: s[x])
        s2 = [0] * n
        now = 0
        for i in range(n):
            # A new rank starts whenever the value differs from the previous
            # one in sorted order.  This has to be the boolean `and`: the
            # bitwise `&` binds tighter than `!=` and handed distinct ranks
            # to equal elements.
            if i and s[idx[i - 1]] != s[idx[i]]:
                now += 1
            s2[idx[i]] = now
        return string.sa_is(s2, now)

    @staticmethod
    def lcp_array(s, sa):
        """Return the LCP array of *s* for its suffix array *sa*.

        Entry ``i`` of the result is the length of the longest common prefix
        of the suffixes starting at ``sa[i]`` and ``sa[i + 1]``; the result
        has ``len(s) - 1`` entries.

        Raises:
            AssertionError: if *s* is empty.
        """
        n = len(s)
        assert n >= 1
        rnk = [0] * n
        for i in range(n):
            rnk[sa[i]] = i
        lcp = [0] * (n - 1)
        h = 0
        for i in range(n):
            # Moving from suffix i - 1 to suffix i loses at most one
            # matched character, so matching resumes from h - 1.
            if h > 0:
                h -= 1
            if rnk[i] == 0:
                continue
            j = sa[rnk[i] - 1]
            while j + h < n and i + h < n:
                if s[j + h] != s[i + h]:
                    break
                h += 1
            lcp[rnk[i] - 1] = h
        return lcp

    @staticmethod
    def z_algorithm(s):
        """Return the Z-array of *s*.

        Entry ``i`` is the length of the longest common prefix of *s* and
        ``s[i:]``; entry 0 is ``len(s)``.  An empty input gives ``[]``.
        """
        n = len(s)
        if n == 0:
            return []
        z = [0] * n
        i = 1
        j = 0
        while i < n:
            # j is the start of the match that reaches furthest right so
            # far; reuse what is already known inside that match.
            z[i] = 0 if (j + z[j] <= i) else min(j + z[j] - i, z[i - j])
            while (i + z[i] < n) and (s[z[i]] == s[i + z[i]]):
                z[i] += 1
            if j + z[j] < i + z[i]:
                j = i
            i += 1
        z[0] = n
        return z
def main():
    """Read the string and the queries from stdin and print one answer each.

    Input format: a line ``n q``, a line with the string, then ``q`` lines
    ``l r`` (1-based, inclusive).  For a query whose substring has length
    ``x`` the printed value is ``cnt * (x - 1) + extra``, where ``cnt`` is
    the number of suffixes sharing at least ``x`` leading characters with
    the suffix starting at ``l`` and ``extra`` sums, over the remaining
    suffixes, their full length (if ranked before it) or the length of
    their common prefix with it (if ranked after it).
    NOTE(review): this looks like a count of the substrings that are
    lexicographically smaller than s[l..r] -- confirm against the problem
    statement.
    """
    readline = sys.stdin.readline
    n, q = map(int, readline().split())
    last = n - 1
    # Strip the line terminator rather than feeding it to the suffix sort
    # as a sentinel and popping the first entry of both arrays afterwards.
    s = readline().strip()
    sa = string.suffix_array(s)
    lcp = string.lcp_array(s, sa)

    rank = [0] * n
    for i, start in enumerate(sa):
        rank[start] = i

    # total_len[i]: summed lengths of the suffixes ranked 0..i.  This list
    # must not be called `len`: rebinding the builtin made the later
    # `len(a)` inside SparseTable.__init__ raise TypeError.
    total_len = [0] * n
    running = 0
    for i, start in enumerate(sa):
        running += n - start
        total_len[i] = running

    table = SparseTable(lcp)

    # tail[i]: sum over j > i of min(lcp[i..j-1]), built right to left by
    # jumping to the next position whose LCP value is <= lcp[i].
    tail = [0] * last
    for i in range(last - 1, -1, -1):
        nxt = table.minLeft(i + 1, lcp[i] + 1)
        tail[i] = lcp[i] * (nxt - i)
        if nxt != last:
            tail[i] += tail[nxt]

    out = []
    for _ in range(q):
        left, right = map(int, readline().split())
        left -= 1
        r, x = rank[left], right - left
        cnt, extra = 1, 0
        if r != 0:
            # Ranks z + 1 .. r - 1 share at least x characters with rank r.
            z = table.maxRight(r - 1, x)
            if z != -1:
                extra += total_len[z]
            cnt += r - 1 - z
        if r != last:
            # Ranks r + 1 .. y share at least x characters with rank r.
            y = table.minLeft(r, x)
            if y != last:
                extra += tail[y]
            cnt += y - r
        out.append(cnt * (x - 1) + extra)
    if out:
        print("\n".join(map(str, out)))


if __name__ == "__main__":
    main()
遭難者