結果

問題 No.430 文字列検索
ユーザー るこーそーるこーそー
提出日時 2024-08-16 09:53:49
言語 PyPy3
(7.3.15)
結果
RE  
実行時間 -
コード長 4,563 bytes
コンパイル時間 324 ms
コンパイル使用メモリ 82,048 KB
実行使用メモリ 80,896 KB
最終ジャッジ日時 2024-11-10 01:12:11
合計ジャッジ時間 2,930 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 55 ms
55,808 KB
testcase_01 RE -
testcase_02 AC 116 ms
78,592 KB
testcase_03 RE -
testcase_04 AC 54 ms
56,320 KB
testcase_05 AC 53 ms
55,936 KB
testcase_06 AC 53 ms
56,192 KB
testcase_07 AC 54 ms
56,448 KB
testcase_08 RE -
testcase_09 AC 60 ms
61,568 KB
testcase_10 RE -
testcase_11 RE -
testcase_12 RE -
testcase_13 RE -
testcase_14 RE -
testcase_15 RE -
testcase_16 RE -
testcase_17 RE -
権限があれば一括ダウンロードができます

ソースコード

diff #

import itertools
from functools import cmp_to_key

def sa_naive(s):
    n = len(s)
    sa = list(range(n))
    sa.sort(key=lambda x: s[x:])
    return sa

def sa_doubling(s):
    n = len(s)
    sa = list(range(n))
    rnk = s[:]
    tmp = [0] * n
    k = 1
    while k < n:
        def cmp(x, y):
            if rnk[x] != rnk[y]:
                return rnk[x] - rnk[y]
            rx = rnk[x + k] if x + k < n else -1
            ry = rnk[y + k] if y + k < n else -1
            return rx - ry

        sa.sort(key=cmp_to_key(cmp))
        tmp[sa[0]] = 0
        for i in range(1, n):
            tmp[sa[i]] = tmp[sa[i - 1]] + (cmp(sa[i - 1], sa[i]) < 0)
        rnk, tmp = tmp, rnk
        k *= 2
    return sa

def sa_is(s, upper):
    THRESHOLD_NAIVE = 10
    THRESHOLD_DOUBLING = 40
    n = len(s)
    if n == 0:
        return []
    if n == 1:
        return [0]
    if n == 2:
        return [0, 1] if s[0] < s[1] else [1, 0]
    if n < THRESHOLD_NAIVE:
        return sa_naive(s)
    if n < THRESHOLD_DOUBLING:
        return sa_doubling(s)

    sa = [-1] * n
    ls = [False] * n
    for i in range(n - 2, -1, -1):
        ls[i] = s[i] < s[i + 1] or (s[i] == s[i + 1] and ls[i + 1])

    sum_l = [0] * (upper + 1)
    sum_s = [0] * (upper + 1)
    for i in range(n):
        if ls[i]:
            sum_l[s[i] + 1] += 1
        else:
            sum_s[s[i]] += 1
    for i in range(upper + 1):
        sum_s[i] += sum_l[i]
        if i < upper:
            sum_l[i + 1] += sum_s[i]

    def induce(lms):
        buf = sum_s[:]
        for d in lms:
            if d == n:
                continue
            sa[buf[s[d]]] = d
            buf[s[d]] += 1

        buf = sum_l[:]
        sa[buf[s[n - 1]]] = n - 1
        buf[s[n - 1]] += 1
        for i in range(n):
            v = sa[i]
            if v >= 1 and not ls[v - 1]:
                sa[buf[s[v - 1]]] = v - 1
                buf[s[v - 1]] += 1

        buf = sum_l[:]
        for i in range(n - 1, -1, -1):
            v = sa[i]
            if v >= 1 and ls[v - 1]:
                buf[s[v - 1] + 1] -= 1
                sa[buf[s[v - 1] + 1]] = v - 1

    lms_map = [-1] * (n + 1)
    m = 0
    for i in range(1, n):
        if not ls[i - 1] and ls[i]:
            lms_map[i] = m
            m += 1

    lms = [i for i in range(1, n) if not ls[i - 1] and ls[i]]

    induce(lms)

    if m:
        sorted_lms = []
        for v in sa:
            if lms_map[v] != -1:
                sorted_lms.append(v)

        rec_s = [0] * m
        rec_upper = 0
        rec_s[lms_map[sorted_lms[0]]] = 0
        for i in range(1, m):
            l = sorted_lms[i - 1]
            r = sorted_lms[i]
            end_l = lms[lms_map[l] + 1] if lms_map[l] + 1 < m else n
            end_r = lms[lms_map[r] + 1] if lms_map[r] + 1 < m else n
            same = True
            if end_l - l != end_r - r:
                same = False
            else:
                while l < end_l:
                    if s[l] != s[r]:
                        same = False
                        break
                    l += 1
                    r += 1
            if not same:
                rec_upper += 1
            rec_s[lms_map[sorted_lms[i]]] = rec_upper

        rec_sa = sa_is(rec_s, rec_upper)

        for i in range(m):
            sorted_lms[i] = lms[rec_sa[i]]

        induce(sorted_lms)

    return sa

def suffix_array(s, upper=None):
    if isinstance(s, str):
        s = [ord(c) for c in s]
    if upper is None:
        upper = max(s)
    return sa_is(s, upper)

def lcp_array(s, sa):
    n = len(s)
    rnk = [0] * n
    for i, suffix in enumerate(sa):
        rnk[suffix] = i
    lcp = [0] * (n - 1)
    h = 0
    for i in range(n):
        if rnk[i] == 0:
            continue
        j = sa[rnk[i] - 1]
        while i + h < n and j + h < n and s[i + h] == s[j + h]:
            h += 1
        lcp[rnk[i] - 1] = h
        if h > 0:
            h -= 1
    return lcp

def z_algorithm(s):
    n = len(s)
    z = [0] * n
    z[0] = n
    l, r = 0, 0
    for i in range(1, n):
        if i <= r:
            z[i] = min(r - i + 1, z[i - l])
        while i + z[i] < n and s[z[i]] == s[i + z[i]]:
            z[i] += 1
        if i + z[i] - 1 > r:
            l, r = i, i + z[i] - 1
    return z



S=input()
SA=suffix_array(S)

def bisect_left(C):
    ng,ok=-1,len(S)
    while ok-ng>1:
        mid=(ok+ng)//2
        if C<=S[SA[mid]:SA[mid]+len(C)]:ok=mid
        else:ng=mid
    return ok

m=int(input())
ans=0
for _ in range(m):
    C=input()
    ans+=bisect_left(C+'~')-bisect_left(C)
print(ans)
0