結果

問題 No.2361 Many String Compare Queries
ユーザー 遭難者
提出日時 2023-01-13 15:27:48
言語 PyPy3
(7.3.15)
結果
RE  
(最新)
AC  
(最初)
実行時間 -
コード長 7,593 bytes
コンパイル時間 250 ms
コンパイル使用メモリ 81,772 KB
実行使用メモリ 115,512 KB
最終ジャッジ日時 2024-06-30 20:39:55
合計ジャッジ時間 11,168 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 41 ms
55,436 KB
testcase_01 AC 41 ms
53,912 KB
testcase_02 AC 44 ms
55,936 KB
testcase_03 AC 42 ms
55,492 KB
testcase_04 AC 41 ms
55,632 KB
testcase_05 AC 41 ms
54,772 KB
testcase_06 AC 40 ms
54,928 KB
testcase_07 AC 62 ms
68,900 KB
testcase_08 AC 2,008 ms
115,512 KB
testcase_09 AC 2,069 ms
113,004 KB
testcase_10 AC 2,019 ms
113,452 KB
testcase_11 RE -
testcase_12 RE -
testcase_13 RE -
testcase_14 AC 1,181 ms
106,256 KB
testcase_15 AC 1,712 ms
111,208 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys


class SegTree:
    """Range-minimum segment tree over a fixed-size array.

    The tree is stored 1-indexed in ``self.data``: node ``k`` has children
    ``2k`` and ``2k + 1``, and the leaves occupy ``data[N2:N2 * 2]``.  Leaves
    beyond ``len(d)`` are padded with ``INF``.

    Attributes:
        N2: number of leaves (smallest power of two >= ``len(d)``).
        N: total size of ``data`` (``2 * N2``).
        data: the node array; ``data[k]`` is the minimum of node ``k``'s range.
        INF: identity value returned for empty ranges and "not found" results
            of ``minLeft``.
    """

    def __init__(self, d, INF):
        """Build the tree from the sequence ``d`` using ``INF`` as padding."""
        n = len(d)
        n2 = 1
        while n2 < n:
            n2 <<= 1
        self.N2 = n2
        self.N = n2 << 1
        self.INF = INF
        self.data = [INF for _ in range(self.N)]
        for offset, value in enumerate(d):
            self.data[n2 + offset] = value
        # Fill internal nodes bottom-up.
        for i in range(n2 - 1, 0, -1):
            ii = i << 1
            self.data[i] = min(self.data[ii], self.data[ii | 1])

    def _prod(self, a, b, k, l, r):
        """Return the minimum of ``[a, b)`` restricted to node ``k`` = ``[l, r)``."""
        if r <= a or b <= l:
            # Node range does not intersect the query.
            return self.INF
        if a <= l and r <= b:
            # Node range is fully covered by the query.
            return self.data[k]
        mid = (l + r) >> 1
        ii = k << 1
        return min(self._prod(a, b, ii, l, mid),
                   self._prod(a, b, ii | 1, mid, r))

    def prod(self, l, r):
        """Return the minimum over the half-open index range ``[l, r)``.

        Returns ``INF`` when the range is empty.
        """
        # The recursive worker has its own name: defining both under the name
        # ``prod`` would leave only the later definition, which would then
        # call itself with the wrong number of arguments.
        return self._prod(l, r, 1, 0, self.N2)

    def set(self, i, x):
        """Assign ``x`` to position ``i`` and refresh all its ancestors."""
        i += self.N2
        self.data[i] = x
        while i != 1:
            i >>= 1
            ii = i << 1
            self.data[i] = min(self.data[ii], self.data[ii | 1])

    def maxRight0(self, idx, jud, k, l, r):
        """Recursive worker for ``maxRight`` on node ``k`` covering ``[l, r)``."""
        # Prune: nothing below ``jud`` in this subtree, or the subtree starts
        # strictly to the right of ``idx``.
        if self.data[k] >= jud or idx < l:
            return -1
        if k >= self.N2:
            return k - self.N2
        ii = k << 1
        mid = (l + r) >> 1
        # Try the right half first so the largest qualifying index wins.
        found = self.maxRight0(idx, jud, ii | 1, mid, r)
        if found != -1:
            return found
        return self.maxRight0(idx, jud, ii, l, mid)

    def maxRight(self, idx, jud):
        """Return the largest ``i <= idx`` with ``data[i] < jud``, or ``-1``."""
        return self.maxRight0(idx, jud, 1, 0, self.N2)

    def minLeft0(self, idx, jud, k, l, r):
        """Recursive worker for ``minLeft`` on node ``k`` covering ``[l, r)``."""
        # Prune: nothing below ``jud`` in this subtree, or the subtree ends
        # before ``idx``.
        if self.data[k] >= jud or r <= idx:
            return self.INF
        if k >= self.N2:
            return k - self.N2
        ii = k << 1
        mid = (l + r) >> 1
        # Try the left half first so the smallest qualifying index wins.
        found = self.minLeft0(idx, jud, ii, l, mid)
        if found != self.INF:
            return found
        return self.minLeft0(idx, jud, ii | 1, mid, r)

    def minLeft(self, idx, jud):
        """Return the smallest ``i >= idx`` with ``data[i] < jud``.

        Returns ``INF`` when no such position exists.  Padding leaves hold
        ``INF`` themselves, so they can match only when ``jud > INF``.
        """
        return self.minLeft0(idx, jud, 1, 0, self.N2)


class string:
    """Namespace of string algorithms: suffix array, LCP array, Z-array.

    All functions are static; call them as ``string.suffix_array(s)`` etc.
    NOTE(review): the structure closely follows AtCoder Library's
    ``atcoder/string`` routines -- presumably a port; confirm against it.
    """

    @staticmethod
    def sa_is(s, upper):
        """Build the suffix array of ``s`` by induced sorting (SA-IS).

        Args:
            s: list of integers, each in ``[0, upper]``.
            upper: largest value that may occur in ``s``.

        Returns:
            List ``sa`` where ``sa[i]`` is the start index of the ``i``-th
            smallest suffix of ``s``.
        """
        n = len(s)
        if n == 0:
            return []
        if n == 1:
            return [0]
        if n == 2:
            if s[0] < s[1]:
                return [0, 1]
            else:
                return [1, 0]
        sa = [0]*n
        # ls[i] is truthy when suffix i is S-type (smaller than suffix i+1),
        # falsy when it is L-type.  The last suffix is L-type.
        ls = [0]*n
        for i in range(n-2, -1, -1):
            ls[i] = ls[i+1] if (s[i] == s[i+1]) else (s[i] < s[i+1])
        # After the prefix-sum loop below, sum_l[c] is the start of bucket c
        # and sum_s[c] is the start of the S-type part of bucket c.
        sum_l = [0]*(upper+1)
        sum_s = [0]*(upper+1)
        for i in range(n):
            if not ls[i]:
                sum_s[s[i]] += 1
            else:
                sum_l[s[i]+1] += 1
        for i in range(upper+1):
            sum_s[i] += sum_l[i]
            if i < upper:
                sum_l[i+1] += sum_s[i]

        def induce(lms):
            """Fill ``sa`` by induced sorting from the given LMS order."""
            for i in range(n):
                sa[i] = -1
            # Seed: place LMS suffixes at the start of their S-type regions.
            buf = sum_s[:]
            for d in lms:
                if d == n:
                    continue
                sa[buf[s[d]]] = d
                buf[s[d]] += 1
            # Forward scan: induce L-type suffixes.
            buf = sum_l[:]
            sa[buf[s[n-1]]] = n-1
            buf[s[n-1]] += 1
            for i in range(n):
                v = sa[i]
                if v >= 1 and not ls[v-1]:
                    sa[buf[s[v-1]]] = v-1
                    buf[s[v-1]] += 1
            # Backward scan: induce S-type suffixes from bucket ends.
            buf = sum_l[:]
            for i in range(n-1, -1, -1):
                v = sa[i]
                if v >= 1 and ls[v-1]:
                    buf[s[v-1]+1] -= 1
                    sa[buf[s[v-1]+1]] = v-1

        # lms_map[i] is the ordinal of LMS position i, or -1 if i is not LMS.
        lms_map = [-1]*(n+1)
        m = 0
        for i in range(1, n):
            if not ls[i-1] and ls[i]:
                lms_map[i] = m
                m += 1
        lms = []
        for i in range(1, n):
            if not ls[i-1] and ls[i]:
                lms.append(i)
        induce(lms)
        if m:
            sorted_lms = []
            for v in sa:
                if lms_map[v] != -1:
                    sorted_lms.append(v)
            # Name each LMS substring; equal substrings share a name.
            rec_s = [0]*m
            rec_upper = 0
            rec_s[lms_map[sorted_lms[0]]] = 0
            for i in range(1, m):
                l = sorted_lms[i-1]
                r = sorted_lms[i]
                end_l = lms[lms_map[l]+1] if (lms_map[l]+1 < m) else n
                end_r = lms[lms_map[r]+1] if (lms_map[r]+1 < m) else n
                same = True
                if end_l-l != end_r-r:
                    same = False
                else:
                    while l < end_l:
                        if s[l] != s[r]:
                            break
                        l += 1
                        r += 1
                    # The ``l == n`` test must come first so that s[l]/s[r]
                    # are never read past the end.
                    if (l == n) or (s[l] != s[r]):
                        same = False
                if not same:
                    rec_upper += 1
                rec_s[lms_map[sorted_lms[i]]] = rec_upper
            # Recurse on the reduced string to get the true LMS order.
            rec_sa = string.sa_is(rec_s, rec_upper)
            for i in range(m):
                sorted_lms[i] = lms[rec_sa[i]]
            induce(sorted_lms)
        return sa

    @staticmethod
    def suffix_array_upper(s, upper):
        """Suffix array of an integer list whose values lie in ``[0, upper]``.

        Raises:
            AssertionError: if ``upper`` is negative or a value is out of range.
        """
        assert 0 <= upper
        for d in s:
            assert 0 <= d and d <= upper
        return string.sa_is(s, upper)

    @staticmethod
    def suffix_array(s):
        """Suffix array of a ``str`` or of any sequence of comparable items.

        For ``str`` input the code points are used directly with an alphabet
        bound of 255, so every character must satisfy ``ord(c) <= 255``.
        Other sequences are first rank-compressed to ``0..k``.
        """
        n = len(s)
        if isinstance(s, str):
            s2 = [ord(i) for i in s]
            return string.sa_is(s2, 255)
        else:
            idx = list(range(n))
            idx.sort(key=lambda x: s[x])
            s2 = [0]*n
            now = 0
            for i in range(n):
                # Start a new rank whenever the value changes.  This must be
                # the short-circuit ``and``: bitwise ``&`` binds tighter than
                # ``!=`` and would compare ``(i & prev)`` with the current
                # value instead.
                if i and s[idx[i-1]] != s[idx[i]]:
                    now += 1
                s2[idx[i]] = now
            return string.sa_is(s2, now)

    @staticmethod
    def lcp_array(s, sa):
        """Return the LCP array for ``s`` given its suffix array ``sa``.

        ``lcp[i]`` is the length of the longest common prefix of the suffixes
        starting at ``sa[i]`` and ``sa[i + 1]``; the result has ``len(s) - 1``
        entries.

        Raises:
            AssertionError: if ``s`` is empty.
        """
        n = len(s)
        assert n >= 1
        rnk = [0]*n
        for i in range(n):
            rnk[sa[i]] = i
        lcp = [0]*(n-1)
        h = 0
        for i in range(n):
            # Carry over the previous match length minus one, then extend.
            if h > 0:
                h -= 1
            if rnk[i] == 0:
                continue
            j = sa[rnk[i]-1]
            while j+h < n and i+h < n:
                if s[j+h] != s[i+h]:
                    break
                h += 1
            lcp[rnk[i]-1] = h
        return lcp

    @staticmethod
    def z_algorithm(s):
        """Return the Z-array of ``s``.

        ``z[i]`` is the length of the longest common prefix of ``s`` and
        ``s[i:]``; ``z[0]`` is ``len(s)``.
        """
        n = len(s)
        if n == 0:
            return []
        z = [0]*n
        i = 1
        j = 0
        while i < n:
            # Reuse the window [j, j + z[j]) when i lies inside it.
            z[i] = 0 if (j+z[j] <= i) else min(j+z[j]-i, z[i-j])
            while (i+z[i] < n) and (s[z[i]] == s[i+z[i]]):
                z[i] += 1
            if j+z[j] < i+z[i]:
                j = i
            i += 1
        z[0] = n
        return z


def main():
    """Read the string and queries from stdin and print one answer per query.

    Input: a line ``n q``, a line with the text, then ``q`` lines ``l r``
    (1-indexed, inclusive).  For each query the answer is built from the
    suffix array and LCP array of the text: every suffix that sorts before
    the block of suffixes sharing the queried substring contributes all of
    its prefixes, every suffix inside the block contributes ``x - 1`` prefixes
    (``x`` being the query length), and every suffix after the block
    contributes its LCP with the queried suffix.
    NOTE(review): this presumably counts the substrings that are
    lexicographically smaller than ``s[l..r]`` -- confirm against the
    problem statement.
    """
    readline = sys.stdin.readline
    n, q = map(int, readline().split())
    n1 = n - 1
    # A trailing "\n" is used as a sentinel that sorts before the text's
    # characters (assumes every character is greater than "\n" -- TODO
    # confirm).  Strip the raw line and append it explicitly so the result
    # does not depend on how the input line happens to be terminated.
    s = readline().strip() + "\n"
    a = string.suffix_array(s)
    b = string.lcp_array(s, a)
    # Drop the sentinel suffix (rank 0) and its LCP entry; ``a`` now has n
    # entries and ``b`` has n - 1.
    a.pop(0)
    b.pop(0)
    a_inv = [0] * n
    for rank, start in enumerate(a):
        a_inv[start] = rank
    # len_prefix[i] = total length of the suffixes with rank <= i.
    len_prefix = [n - start for start in a]
    for i in range(1, n):
        len_prefix[i] += len_prefix[i - 1]
    seg = SegTree(b, n1)
    # p[i] = sum over j >= i of min(b[i..j]), computed right to left by
    # jumping to the next position whose value is <= b[i].
    p = [0] * n1
    for i in range(n1 - 1, -1, -1):
        nxt = seg.minLeft(i + 1, b[i] + 1)
        p[i] = b[i] * (nxt - i)
        if nxt != n1:
            p[i] += p[nxt]
    answers = []
    for _ in range(q):
        ll, rr = map(int, readline().split())
        ll -= 1
        l, x = a_inv[ll], rr - ll
        # ans1 counts suffixes whose LCP with the queried suffix is >= x
        # (including the queried suffix itself); ans2 accumulates the
        # contributions of all other suffixes.
        ans1, ans2 = 1, 0
        if l != 0:
            # Last boundary before rank l where the LCP drops below x.
            z = seg.maxRight(l - 1, x)
            if z != -1:
                ans2 += len_prefix[z]
            ans1 += l - 1 - z
        if l != n1:
            # First boundary at or after rank l where the LCP drops below x.
            y = seg.minLeft(l, x)
            if y != n1:
                ans2 += p[y]
            ans1 += y - l
        answers.append(ans1 * (x - 1) + ans2)
    # Emit all answers in one write; produces nothing when q == 0.
    sys.stdout.write("".join(f"{v}\n" for v in answers))


# Run the solver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
0