結果
問題 | No.430 文字列検索 |
ユーザー | るこーそー |
提出日時 | 2024-08-16 09:53:49 |
言語 | PyPy3 (7.3.15) |
結果 | RE |
実行時間 | - |
コード長 | 4,563 bytes |
コンパイル時間 | 324 ms |
コンパイル使用メモリ | 82,048 KB |
実行使用メモリ | 80,896 KB |
最終ジャッジ日時 | 2024-11-10 01:12:11 |
合計ジャッジ時間 | 2,930 ms |
ジャッジサーバーID (参考情報) | judge2 / judge1 |
(要ログイン)
テストケース
テストケース表示
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 55 ms | 55,808 KB |
testcase_01 | RE | - | - |
testcase_02 | AC | 116 ms | 78,592 KB |
testcase_03 | RE | - | - |
testcase_04 | AC | 54 ms | 56,320 KB |
testcase_05 | AC | 53 ms | 55,936 KB |
testcase_06 | AC | 53 ms | 56,192 KB |
testcase_07 | AC | 54 ms | 56,448 KB |
testcase_08 | RE | - | - |
testcase_09 | AC | 60 ms | 61,568 KB |
testcase_10 | RE | - | - |
testcase_11 | RE | - | - |
testcase_12 | RE | - | - |
testcase_13 | RE | - | - |
testcase_14 | RE | - | - |
testcase_15 | RE | - | - |
testcase_16 | RE | - | - |
testcase_17 | RE | - | - |
ソースコード
# yukicoder No.430 (string search): total number of occurrences of M short
# patterns in S, answered with a suffix array and two binary searches per
# pattern.  The suffix-array helpers are structured after the AtCoder Library
# (ACL) string module.
import itertools
from functools import cmp_to_key


def sa_naive(s):
    """Return the suffix array of ``s`` by sorting the suffix slices directly.

    ``s`` may be a ``str`` or any sliceable sequence of comparable items.
    Every sort key is a full slice, so this is only meant for tiny inputs.
    """
    n = len(s)
    sa = list(range(n))
    sa.sort(key=lambda x: s[x:])
    return sa


def sa_doubling(s):
    """Return the suffix array of ``s`` by prefix doubling.

    ``s`` is a sequence of non-negative ints (any sequence type; it is copied
    into a list).  After the round with step ``k``, ``rnk[i]`` is the rank of
    the length-``2k`` prefix of suffix ``i``.
    """
    n = len(s)
    sa = list(range(n))
    rnk = list(s)  # a list, so the buffer swap below always yields lists
    tmp = [0] * n
    k = 1
    while k < n:
        # Sort key: (rank of first half, rank of second half).  A suffix whose
        # second half runs past the end gets -1, so a proper prefix sorts
        # first (ranks themselves are always >= 0).
        key = [(rnk[i], rnk[i + k] if i + k < n else -1) for i in range(n)]
        sa.sort(key=key.__getitem__)
        tmp[sa[0]] = 0
        for i in range(1, n):
            tmp[sa[i]] = tmp[sa[i - 1]] + (key[sa[i - 1]] < key[sa[i]])
        rnk, tmp = tmp, rnk
        k *= 2
    return sa


def sa_is(s, upper, threshold_naive=10, threshold_doubling=40):
    """Return the suffix array of ``s`` using SA-IS (induced sorting).

    ``s`` is a sequence of ints, each in ``[0, upper]``.  Inputs shorter than
    ``threshold_naive`` use :func:`sa_naive` and inputs shorter than
    ``threshold_doubling`` use :func:`sa_doubling`.  Both thresholds are
    forwarded to the recursive call, so lowering them forces the SA-IS path
    on short inputs (useful for testing).
    """
    n = len(s)
    if n == 0:
        return []
    if n == 1:
        return [0]
    if n == 2:
        return [0, 1] if s[0] < s[1] else [1, 0]
    if n < threshold_naive:
        return sa_naive(s)
    if n < threshold_doubling:
        return sa_doubling(s)

    # ls[i] is True iff suffix i is S-type (smaller than suffix i + 1).
    # The last suffix is L-type.
    ls = [False] * n
    for i in range(n - 2, -1, -1):
        ls[i] = ls[i + 1] if s[i] == s[i + 1] else s[i] < s[i + 1]

    # Bucket bounds per character c: sum_l[c] is where bucket c starts (its
    # L-type suffixes come first) and sum_s[c] is where its S-type part
    # starts.
    sum_l = [0] * (upper + 1)
    sum_s = [0] * (upper + 1)
    for i in range(n):
        if ls[i]:
            # s[i] + 1 <= upper: a suffix starting with the largest
            # character can never be S-type.
            sum_l[s[i] + 1] += 1
        else:
            sum_s[s[i]] += 1
    for i in range(upper + 1):
        sum_s[i] += sum_l[i]
        if i < upper:
            sum_l[i + 1] += sum_s[i]

    sa = [-1] * n

    def induce(lms):
        """Induce the order of all suffixes from LMS positions in ``lms``."""
        # The scans below treat -1 as "empty slot".  This runs twice, and
        # stale entries left by the first run would otherwise be induced
        # again, overrunning bucket bounds or corrupting the order.
        sa[:] = [-1] * n
        # Seed the LMS suffixes at the front of their S-type parts.
        buf = sum_s[:]
        for d in lms:
            if d == n:
                continue
            sa[buf[s[d]]] = d
            buf[s[d]] += 1
        # Left-to-right: place L-type suffixes at the front of their buckets.
        buf = sum_l[:]
        sa[buf[s[n - 1]]] = n - 1
        buf[s[n - 1]] += 1
        for i in range(n):
            v = sa[i]
            if v >= 1 and not ls[v - 1]:
                sa[buf[s[v - 1]]] = v - 1
                buf[s[v - 1]] += 1
        # Right-to-left: place S-type suffixes from the end of their buckets
        # (buf[c + 1] starts out as the end of bucket c).
        buf = sum_l[:]
        for i in range(n - 1, -1, -1):
            v = sa[i]
            if v >= 1 and ls[v - 1]:
                buf[s[v - 1] + 1] -= 1
                sa[buf[s[v - 1] + 1]] = v - 1

    # LMS positions: S-type positions immediately preceded by an L-type one.
    lms_map = [-1] * (n + 1)
    lms = []
    for i in range(1, n):
        if not ls[i - 1] and ls[i]:
            lms_map[i] = len(lms)
            lms.append(i)
    m = len(lms)

    induce(lms)

    if m:
        sorted_lms = [v for v in sa if lms_map[v] != -1]
        # Name each LMS substring by its rank among the distinct LMS
        # substrings.  A substring runs from its LMS position up to and
        # including the next LMS position.  The last one is closed by the
        # virtual end-of-string sentinel.
        rec_s = [0] * m
        rec_upper = 0
        for i in range(1, m):
            l = sorted_lms[i - 1]
            r = sorted_lms[i]
            end_l = lms[lms_map[l] + 1] if lms_map[l] + 1 < m else n
            end_r = lms[lms_map[r] + 1] if lms_map[r] + 1 < m else n
            same = True
            if end_l - l != end_r - r:
                same = False
            else:
                while l < end_l and s[l] == s[r]:
                    l += 1
                    r += 1
                # Compare the closing character too.  The sentinel (reached
                # at n) equals nothing.
                if l == n or r == n or s[l] != s[r]:
                    same = False
            if not same:
                rec_upper += 1
            rec_s[lms_map[sorted_lms[i]]] = rec_upper

        # Sort the LMS suffixes via the reduced string, then induce again.
        rec_sa = sa_is(rec_s, rec_upper, threshold_naive, threshold_doubling)
        sorted_lms = [lms[j] for j in rec_sa]
        induce(sorted_lms)
    return sa


def suffix_array(s, upper=None):
    """Return the suffix array of ``s``.

    For a ``str`` the code points are used directly.  For another sequence
    with ``upper`` given, every element must be an int in ``[0, upper]``.
    Without ``upper``, the values are first compressed to their ranks, so any
    hashable, mutually comparable values work (including negative or very
    large ints).
    """
    if isinstance(s, str):
        s = [ord(c) for c in s]
        if upper is None:
            upper = max(s, default=0)
    elif upper is None:
        ranks = {v: i for i, v in enumerate(sorted(set(s)))}
        s = [ranks[v] for v in s]
        upper = max(len(ranks) - 1, 0)
    return sa_is(s, upper)


def lcp_array(s, sa):
    """Return the LCP array of ``s`` for its suffix array ``sa`` (Kasai).

    ``lcp[i]`` is the length of the longest common prefix of suffixes
    ``sa[i]`` and ``sa[i + 1]``; the result has ``len(s) - 1`` entries.
    """
    n = len(s)
    rnk = [0] * n
    for i, suffix in enumerate(sa):
        rnk[suffix] = i
    lcp = [0] * (n - 1)
    h = 0
    for i in range(n):
        if rnk[i] == 0:
            continue
        j = sa[rnk[i] - 1]
        while i + h < n and j + h < n and s[i + h] == s[j + h]:
            h += 1
        lcp[rnk[i] - 1] = h
        if h > 0:
            # Kasai's invariant: suffix i + 1 shares at least h - 1
            # characters with its predecessor, so resume from there.
            h -= 1
    return lcp


def z_algorithm(s):
    """Return the Z-array of ``s``.

    ``z[i]`` is the length of the longest common prefix of ``s`` and
    ``s[i:]``.  Empty input gives an empty list.
    """
    n = len(s)
    if n == 0:
        return []
    z = [0] * n
    z[0] = n
    l, r = 0, 0  # rightmost match window s[l..r] found so far
    for i in range(1, n):
        if i <= r:
            z[i] = min(r - i + 1, z[i - l])
        while i + z[i] < n and s[z[i]] == s[i + z[i]]:
            z[i] += 1
        if i + z[i] - 1 > r:
            l, r = i, i + z[i] - 1
    return z


def count_occurrences(text, sa, pattern):
    """Return how many times ``pattern`` occurs in ``text`` (overlaps count).

    ``sa`` must be the suffix array of ``text``.  In suffix-array order the
    width-``len(pattern)`` prefixes of the suffixes are non-decreasing, so the
    matches form one contiguous run.  Its bounds are found with two binary
    searches, with no assumption about which characters may occur.
    """
    width = len(pattern)

    def first_index(strict):
        # Smallest rank whose prefix is > pattern (strict) or >= pattern.
        lo, hi = -1, len(sa)
        while hi - lo > 1:
            mid = (lo + hi) // 2
            prefix = text[sa[mid]:sa[mid] + width]
            if (pattern < prefix) if strict else (pattern <= prefix):
                hi = mid
            else:
                lo = mid
        return hi

    return first_index(True) - first_index(False)


def main():
    """Read S, then M patterns; print the total number of occurrences."""
    text = input().strip()
    sa = suffix_array(text)
    queries = int(input())
    total = 0
    for _ in range(queries):
        total += count_occurrences(text, sa, input().strip())
    print(total)


if __name__ == "__main__":
    main()