結果
問題 | No.2361 Many String Compare Queries |
ユーザー | 遭難者 |
提出日時 | 2023-01-14 18:53:35 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 615 ms / 2,500 ms |
コード長 | 6,760 bytes |
コンパイル時間 | 150 ms |
コンパイル使用メモリ | 82,420 KB |
実行使用メモリ | 127,712 KB |
最終ジャッジ日時 | 2024-06-30 20:40:32 |
合計ジャッジ時間 | 5,701 ms |
ジャッジサーバーID (参考情報) | judge2 / judge1 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 39 ms 54,660 KB |
testcase_01 | AC | 42 ms 55,748 KB |
testcase_02 | AC | 40 ms 54,480 KB |
testcase_03 | AC | 41 ms 54,160 KB |
testcase_04 | AC | 41 ms 54,364 KB |
testcase_05 | AC | 39 ms 54,436 KB |
testcase_06 | AC | 40 ms 54,236 KB |
testcase_07 | AC | 51 ms 62,264 KB |
testcase_08 | AC | 615 ms 127,712 KB |
testcase_09 | AC | 601 ms 124,612 KB |
testcase_10 | AC | 614 ms 125,008 KB |
testcase_11 | AC | 430 ms 117,448 KB |
testcase_12 | AC | 465 ms 115,788 KB |
testcase_13 | AC | 426 ms 114,884 KB |
testcase_14 | AC | 339 ms 123,976 KB |
testcase_15 | AC | 419 ms 123,852 KB |
ソースコード
import sys


class SparseTable:
    """Static range-minimum table over a fixed sequence.

    ``d[k][j]`` holds ``min(a[j : j + 2**k])``.  Besides plain range
    minimum queries, the table supports two binary-lifting scans that find
    the nearest entry strictly below a threshold on either side of an index.
    """

    def __init__(self, a):
        """Build the table from the sliceable sequence ``a``."""
        self.n = len(a)
        # log[k] == floor(log2(k)) for k >= 1; log[0] is left at 0.
        self.log = [0] * (self.n + 1)
        for i in range(2, self.n + 1):
            self.log[i] = self.log[i >> 1] + 1
        self.m = self.log[self.n]
        self.d = [[] for _ in range(self.m + 1)]
        self.d[0] = a[:]
        for i in range(self.m):
            half = 1 << i
            width = self.n + 1 - (half << 1)
            prev = self.d[i]
            self.d[i + 1] = [min(prev[j], prev[j + half]) for j in range(width)]

    def min(self, l, r):
        """Return ``min(a[l:r])``; the range must be non-empty (``l < r``)."""
        bit = self.log[r - l]
        return min(self.d[bit][l], self.d[bit][r - (1 << bit)])

    def maxRight(self, idx, jud):
        """Scan leftwards from ``idx`` while entries are ``>= jud``.

        Returns the largest index ``<= idx`` whose entry is ``< jud``, or
        ``-1`` when every entry in ``a[0 : idx + 1]`` is ``>= jud``.
        """
        ans = idx
        for i in range(self.m, -1, -1):
            x = ans - (1 << i) + 1
            # d[i][x] is the minimum of the window a[x : ans + 1].
            if x >= 0 and self.d[i][x] >= jud:
                ans -= 1 << i
        return ans

    def minLeft(self, idx, jud):
        """Scan rightwards from ``idx`` while entries are ``>= jud``.

        Returns the smallest index ``>= idx`` whose entry is ``< jud``, or
        ``n`` when every entry in ``a[idx:]`` is ``>= jud``.
        """
        ans = idx
        for i in range(self.m, -1, -1):
            # The length check keeps the window a[ans : ans + 2**i] in range.
            if ans < len(self.d[i]) and self.d[i][ans] >= jud:
                ans += 1 << i
        return ans


class string:
    """Namespace of string algorithms (suffix array, LCP array, Z array).

    All functions are called through the class, e.g. ``string.sa_is(...)``.
    """

    @staticmethod
    def sa_is(s, upper):
        """Return the suffix array of ``s`` using the SA-IS algorithm.

        ``s`` is a sequence of ints with ``0 <= s[i] <= upper``.
        """
        n = len(s)
        if n == 0:
            return []
        if n == 1:
            return [0]
        if n == 2:
            return [0, 1] if s[0] < s[1] else [1, 0]

        sa = [0] * n
        # ls[i] is truthy when suffix i is smaller than suffix i + 1 ("S-type").
        ls = [False] * n
        for i in range(n - 2, -1, -1):
            ls[i] = ls[i + 1] if s[i] == s[i + 1] else s[i] < s[i + 1]

        # sum_l[c]: first SA slot of bucket c.
        # sum_s[c]: first SA slot of the S-type part of bucket c.
        sum_l = [0] * (upper + 1)
        sum_s = [0] * (upper + 1)
        for i in range(n):
            if not ls[i]:
                sum_s[s[i]] += 1
            else:
                sum_l[s[i] + 1] += 1
        for i in range(upper + 1):
            sum_s[i] += sum_l[i]
            if i < upper:
                sum_l[i + 1] += sum_s[i]

        def induce(lms):
            """Fill ``sa`` by induced sorting from the given LMS order."""
            for i in range(n):
                sa[i] = -1
            buf = sum_s[:]
            for d in lms:
                if d == n:
                    continue
                sa[buf[s[d]]] = d
                buf[s[d]] += 1
            # Left-to-right pass places the L-type suffixes.
            buf = sum_l[:]
            sa[buf[s[n - 1]]] = n - 1
            buf[s[n - 1]] += 1
            for i in range(n):
                v = sa[i]
                if v >= 1 and not ls[v - 1]:
                    sa[buf[s[v - 1]]] = v - 1
                    buf[s[v - 1]] += 1
            # Right-to-left pass places the S-type suffixes.
            buf = sum_l[:]
            for i in range(n - 1, -1, -1):
                v = sa[i]
                if v >= 1 and ls[v - 1]:
                    buf[s[v - 1] + 1] -= 1
                    sa[buf[s[v - 1] + 1]] = v - 1

        # LMS positions: an S-type suffix directly preceded by an L-type one.
        lms_map = [-1] * (n + 1)
        lms = []
        for i in range(1, n):
            if not ls[i - 1] and ls[i]:
                lms_map[i] = len(lms)
                lms.append(i)
        m = len(lms)

        induce(lms)

        if m:
            sorted_lms = [v for v in sa if lms_map[v] != -1]
            # Name the LMS substrings; equal substrings share a name.
            rec_s = [0] * m
            rec_upper = 0
            rec_s[lms_map[sorted_lms[0]]] = 0
            for i in range(1, m):
                l = sorted_lms[i - 1]
                r = sorted_lms[i]
                end_l = lms[lms_map[l] + 1] if lms_map[l] + 1 < m else n
                end_r = lms[lms_map[r] + 1] if lms_map[r] + 1 < m else n
                same = True
                if end_l - l != end_r - r:
                    same = False
                else:
                    while l < end_l:
                        if s[l] != s[r]:
                            break
                        l += 1
                        r += 1
                    if l == n or s[l] != s[r]:
                        same = False
                if not same:
                    rec_upper += 1
                rec_s[lms_map[sorted_lms[i]]] = rec_upper

            # Sort the reduced string recursively, then induce the final order.
            rec_sa = string.sa_is(rec_s, rec_upper)
            for i in range(m):
                sorted_lms[i] = lms[rec_sa[i]]
            induce(sorted_lms)
        return sa

    @staticmethod
    def suffix_array_upper(s, upper):
        """Return the suffix array of an int sequence with values in ``[0, upper]``."""
        assert 0 <= upper
        for d in s:
            assert 0 <= d and d <= upper
        return string.sa_is(s, upper)

    @staticmethod
    def suffix_array(s):
        """Return the suffix array of ``s``.

        ``s`` is either a ``str`` whose characters have code points
        ``<= 255``, or any sequence of mutually comparable items, which is
        rank-compressed before sorting.
        """
        n = len(s)
        if isinstance(s, str):
            return string.sa_is([ord(ch) for ch in s], 255)

        idx = sorted(range(n), key=lambda x: s[x])
        s2 = [0] * n
        now = 0
        for i in range(n):
            # A new rank starts whenever the value differs from the previous
            # one in sorted order.  (The original used ``i & ...``, which
            # parses as ``(i & prev) != cur`` and produced wrong ranks.)
            if i and s[idx[i - 1]] != s[idx[i]]:
                now += 1
            s2[idx[i]] = now
        return string.sa_is(s2, now)

    @staticmethod
    def lcp_array(s, sa):
        """Return the LCP array for ``s`` and its suffix array ``sa``.

        Entry ``i`` is the length of the longest common prefix of the
        suffixes ``sa[i]`` and ``sa[i + 1]``; the result has ``len(s) - 1``
        entries.  ``s`` must be non-empty.
        """
        n = len(s)
        assert n >= 1
        rnk = [0] * n
        for i in range(n):
            rnk[sa[i]] = i
        lcp = [0] * (n - 1)
        h = 0
        for i in range(n):
            # The match length drops by at most one when moving to suffix i.
            if h > 0:
                h -= 1
            if rnk[i] == 0:
                continue
            j = sa[rnk[i] - 1]
            while j + h < n and i + h < n:
                if s[j + h] != s[i + h]:
                    break
                h += 1
            lcp[rnk[i] - 1] = h
        return lcp

    @staticmethod
    def z_algorithm(s):
        """Return the Z array of ``s``.

        ``z[i]`` is the length of the longest common prefix of ``s`` and
        ``s[i:]``; ``z[0]`` is ``len(s)``.
        """
        n = len(s)
        if n == 0:
            return []
        z = [0] * n
        i = 1
        j = 0
        while i < n:
            z[i] = 0 if j + z[j] <= i else min(j + z[j] - i, z[i - j])
            while i + z[i] < n and s[z[i]] == s[i + z[i]]:
                z[i] += 1
            if j + z[j] < i + z[i]:
                j = i
            i += 1
        z[0] = n
        return z


def _count_smaller_substrings(s, queries):
    """Count, per query, the substrings of ``s`` smaller than the queried one.

    Each query is a 1-based inclusive pair ``(ll, rr)``.  The answer is the
    number of (start, end) position pairs whose substring compares strictly
    less than ``s[ll - 1 : rr]``.  ``s`` must be non-empty.  Returns the list
    of answers in query order.
    """
    n = len(s)
    last = n - 1
    sa = string.suffix_array(s)
    lcp = string.lcp_array(s, sa)

    rank = [0] * n
    for i in range(n):
        rank[sa[i]] = i

    # total_len[i]: combined length of the i + 1 smallest suffixes.
    total_len = [n - start for start in sa]
    for i in range(1, n):
        total_len[i] += total_len[i - 1]

    table = SparseTable(lcp)

    # tail[i]: sum over r in (i, n) of min(lcp[i:r]), i.e. the total LCP
    # between the suffix of rank i and every suffix of larger rank.
    tail = [0] * last
    for i in range(last - 1, -1, -1):
        nxt = table.minLeft(i + 1, lcp[i] + 1)  # first later entry <= lcp[i]
        tail[i] = lcp[i] * (nxt - i)
        if nxt != last:
            tail[i] += tail[nxt]

    answers = []
    for ll, rr in queries:
        start = ll - 1
        x = rr - start  # length of the queried substring
        r = rank[start]
        # ``full`` counts suffixes sharing at least x characters with the
        # query (itself included); each contributes its x - 1 shorter prefixes.
        full = 1
        extra = 0
        if r != 0:
            z = table.maxRight(r - 1, x)
            if z != -1:
                # Smaller suffixes that diverge early: all their prefixes count.
                extra += total_len[z]
            full += r - 1 - z
        if r != last:
            y = table.minLeft(r, x)
            if y != last:
                # Larger suffixes that diverge early: only the common prefix counts.
                extra += tail[y]
            full += y - r
        answers.append(full * (x - 1) + extra)
    return answers


def main():
    """Read the string and queries from stdin and print one answer per query."""
    readline = sys.stdin.readline
    _, q = map(int, readline().split())
    # Drop the line terminator explicitly so CRLF input does not corrupt the
    # string; the suffix structures are built on the bare string.
    s = readline().rstrip("\r\n")
    queries = [tuple(map(int, readline().split())) for _ in range(q)]
    answers = _count_smaller_substrings(s, queries)
    if answers:
        print("\n".join(map(str, answers)))


if __name__ == "__main__":
    main()