結果

問題 No.2361 Many String Compare Queries
ユーザー 遭難者
提出日時 2023-01-14 18:53:35
言語 PyPy3 (7.3.15)
結果 AC
実行時間 615 ms / 2,500 ms
コード長 6,760 bytes
コンパイル時間 150 ms
コンパイル使用メモリ 82,420 KB
実行使用メモリ 127,712 KB
最終ジャッジ日時 2024-06-30 20:40:32
合計ジャッジ時間 5,701 ms
ジャッジサーバーID (参考情報) judge2 / judge1
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 39 ms 54,660 KB
testcase_01 AC 42 ms 55,748 KB
testcase_02 AC 40 ms 54,480 KB
testcase_03 AC 41 ms 54,160 KB
testcase_04 AC 41 ms 54,364 KB
testcase_05 AC 39 ms 54,436 KB
testcase_06 AC 40 ms 54,236 KB
testcase_07 AC 51 ms 62,264 KB
testcase_08 AC 615 ms 127,712 KB
testcase_09 AC 601 ms 124,612 KB
testcase_10 AC 614 ms 125,008 KB
testcase_11 AC 430 ms 117,448 KB
testcase_12 AC 465 ms 115,788 KB
testcase_13 AC 426 ms 114,884 KB
testcase_14 AC 339 ms 123,976 KB
testcase_15 AC 419 ms 123,852 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys


class SparseTable:
    """Static range-minimum table over a fixed list (sparse table).

    Construction is O(n log n) time and memory.  ``min`` answers a
    half-open range query in O(1); ``maxRight`` and ``minLeft`` are
    O(log n) binary-lifting searches over the stored levels.

    Attributes:
        n: length of the input list.
        log: ``log[k] == floor(log2(k))`` for ``k >= 1`` (``log[0] == 0``).
        m: ``log[n]``, the highest stored level.
        d: ``d[i][j] == min(a[j:j + 2**i])``; level ``i`` holds
            ``n + 1 - 2**i`` entries.
    """

    def __init__(self, a):
        self.n = len(a)
        self.log = [0] * (self.n + 1)
        for i in range(2, self.n + 1):
            self.log[i] = self.log[i >> 1] + 1
        self.m = self.log[self.n]
        self.d = [list(a)]
        for i in range(self.m):
            prev = self.d[i]
            half = 1 << i
            # A window of width 2**(i+1) is two adjacent level-i windows.
            self.d.append(
                [min(prev[j], prev[j + half])
                 for j in range(self.n + 1 - (half << 1))]
            )

    def min(self, l, r):
        """Return ``min(a[l:r])`` in O(1).

        Raises:
            ValueError: unless ``0 <= l < r <= n``.  An empty or
                out-of-bounds range has no minimum, and indexing the
                table with it would silently read an unrelated element.
        """
        if not 0 <= l < r <= self.n:
            raise ValueError(f"invalid range [{l}, {r}) for length {self.n}")
        bit = self.log[r - l]
        return min(self.d[bit][l], self.d[bit][r - (1 << bit)])

    def maxRight(self, idx,  jud):
        """Scan LEFTWARDS from ``idx`` while values stay ``>= jud``.

        Despite its name, this returns ``L - 1``, where ``L`` is the
        smallest index such that every ``a[k]`` with ``L <= k <= idx`` is
        ``>= jud``.  So it returns ``idx`` when ``a[idx] < jud``, and ``-1``
        when the whole prefix ``a[0..idx]`` qualifies.  Expects
        ``-1 <= idx < n``.
        """
        ans = idx
        for i in range(self.m, -1, -1):
            # Candidate window of width 2**i ending exactly at ``ans``.
            x = ans - (1 << i) + 1
            if x >= 0 and self.d[i][x] >= jud:
                ans -= 1 << i
        return ans

    def minLeft(self, idx, jud):
        """Scan RIGHTWARDS from ``idx`` while values stay ``>= jud``.

        Despite its name, this returns the smallest ``r >= idx`` with
        ``a[r] < jud``, or ``n`` if no such index exists.
        """
        ans = idx
        for i in range(self.m, -1, -1):
            # Candidate window of width 2**i starting at ``ans``.
            if ans < len(self.d[i]) and self.d[i][ans] >= jud:
                ans += 1 << i
        return ans


class string:
    """String algorithms: suffix array (SA-IS), LCP array, Z-algorithm.

    The structure closely mirrors the AtCoder Library ``string`` header
    (presumed port — verify against upstream if exact parity matters).
    Every member is a static method, called as ``string.suffix_array(s)``.
    """

    @staticmethod
    def sa_is(s, upper):
        """Return the suffix array of ``s`` using SA-IS.

        Args:
            s: list of ints, each in ``[0, upper]``.
            upper: bucket bound; every element of ``s`` must be <= upper.

        Returns:
            List ``sa`` where ``sa[i]`` is the start index of the i-th
            smallest suffix of ``s``.
        """
        n = len(s)
        if n == 0:
            return []
        if n == 1:
            return [0]
        if n == 2:
            if (s[0] < s[1]):
                return [0, 1]
            else:
                return [1, 0]
        sa = [0]*n
        # ls[i] is truthy when suffix i is S-type (smaller than suffix i+1).
        ls = [0]*n
        for i in range(n-2, -1, -1):
            ls[i] = ls[i+1] if (s[i] == s[i+1]) else (s[i] < s[i+1])
        # After the prefix pass below, sum_l[c] is the first slot of the
        # bucket for value c, and sum_s[c] is the first slot of that
        # bucket's S-type region.
        sum_l = [0]*(upper+1)
        sum_s = [0]*(upper+1)
        for i in range(n):
            if not (ls[i]):
                sum_s[s[i]] += 1
            else:
                sum_l[s[i]+1] += 1
        for i in range(upper+1):
            sum_s[i] += sum_l[i]
            if i < upper:
                sum_l[i+1] += sum_s[i]

        def induce(lms):
            # Induced sort: seed the LMS positions, then place L-type
            # suffixes left-to-right and S-type suffixes right-to-left.
            for i in range(n):
                sa[i] = -1
            buf = sum_s[:]
            for d in lms:
                if d == n:
                    continue
                sa[buf[s[d]]] = d
                buf[s[d]] += 1
            buf = sum_l[:]
            sa[buf[s[n-1]]] = n-1
            buf[s[n-1]] += 1
            for i in range(n):
                v = sa[i]
                if v >= 1 and not (ls[v-1]):
                    sa[buf[s[v-1]]] = v-1
                    buf[s[v-1]] += 1
            buf = sum_l[:]
            for i in range(n-1, -1, -1):
                v = sa[i]
                if v >= 1 and ls[v-1]:
                    buf[s[v-1]+1] -= 1
                    sa[buf[s[v-1]+1]] = v-1
        # lms_map[i] is the ordinal of LMS position i, or -1 otherwise.
        lms_map = [-1]*(n+1)
        m = 0
        for i in range(1, n):
            if not (ls[i-1]) and ls[i]:
                lms_map[i] = m
                m += 1
        lms = []
        for i in range(1, n):
            if not (ls[i-1]) and ls[i]:
                lms.append(i)
        induce(lms)
        if m:
            sorted_lms = []
            for v in sa:
                if lms_map[v] != -1:
                    sorted_lms.append(v)
            # Name each LMS substring; equal substrings share a name.
            rec_s = [0]*m
            rec_upper = 0
            rec_s[lms_map[sorted_lms[0]]] = 0
            for i in range(1, m):
                l = sorted_lms[i-1]
                r = sorted_lms[i]
                end_l = lms[lms_map[l]+1] if (lms_map[l]+1 < m) else n
                end_r = lms[lms_map[r]+1] if (lms_map[r]+1 < m) else n
                same = True
                if end_l-l != end_r-r:
                    same = False
                else:
                    while (l < end_l):
                        if s[l] != s[r]:
                            break
                        l += 1
                        r += 1
                    if (l == n) or (s[l] != s[r]):
                        same = False
                if not (same):
                    rec_upper += 1
                rec_s[lms_map[sorted_lms[i]]] = rec_upper
            # Recurse on the reduced string to fully order the LMS suffixes.
            rec_sa = string.sa_is(rec_s, rec_upper)
            for i in range(m):
                sorted_lms[i] = lms[rec_sa[i]]
            induce(sorted_lms)
        return sa

    @staticmethod
    def suffix_array_upper(s, upper):
        """Suffix array of an int list whose values lie in ``[0, upper]``."""
        assert 0 <= upper
        for d in s:
            assert 0 <= d and d <= upper
        return string.sa_is(s, upper)

    @staticmethod
    def suffix_array(s):
        """Return the suffix array of a ``str`` or of a sequence of
        mutually comparable values.

        A ``str`` is mapped to code points.  Any other sequence is first
        compressed to dense ranks ``0..k`` (equal values share a rank).
        """
        n = len(s)
        if isinstance(s, str):
            s2 = [ord(c) for c in s]
            # Size the buckets by the largest code point present, so
            # characters above U+00FF do not overflow a fixed 255 bound.
            return string.sa_is(s2, max(s2, default=0))
        idx = sorted(range(n), key=lambda x: s[x])
        s2 = [0]*n
        now = 0
        for i in range(n):
            # Start a new rank whenever the value differs from its
            # predecessor in sorted order.  This must be logical ``and``:
            # ``i & a != b`` parses as ``(i & a) != b``.
            if i and s[idx[i-1]] != s[idx[i]]:
                now += 1
            s2[idx[i]] = now
        return string.sa_is(s2, now)

    @staticmethod
    def lcp_array(s, sa):
        """Kasai's algorithm: ``lcp[i]`` is the longest common prefix of
        the suffixes starting at ``sa[i]`` and ``sa[i + 1]``.

        Returns a list of length ``len(s) - 1``.  Requires ``len(s) >= 1``.
        """
        n = len(s)
        assert n >= 1
        rnk = [0]*n
        for i in range(n):
            rnk[sa[i]] = i
        lcp = [0]*(n-1)
        h = 0
        for i in range(n):
            # The LCP shrinks by at most one when moving from suffix i-1
            # to suffix i, so h is reused rather than recomputed.
            if h > 0:
                h -= 1
            if rnk[i] == 0:
                continue
            j = sa[rnk[i]-1]
            while (j+h < n and i+h < n):
                if s[j+h] != s[i+h]:
                    break
                h += 1
            lcp[rnk[i]-1] = h
        return lcp

    @staticmethod
    def z_algorithm(s):
        """Return ``z`` where ``z[i]`` is the length of the longest common
        prefix of ``s`` and ``s[i:]`` (so ``z[0] == len(s)``)."""
        n = len(s)
        if n == 0:
            return []
        z = [0]*n
        i = 1
        j = 0
        while (i < n):
            # Reuse the rightmost matched window [j, j + z[j]) if it covers i.
            z[i] = 0 if (j+z[j] <= i) else min(j+z[j]-i, z[i-j])
            while ((i+z[i] < n) and (s[z[i]] == s[i+z[i]])):
                z[i] += 1
            if (j+z[j] < i+z[i]):
                j = i
            i += 1
        z[0] = n
        return z


def main():
    """Read N, Q, S and Q queries ``(L, R)`` from stdin; print one value
    per query.

    For ``T = S[L-1:R]`` with ``x = len(T)``, the printed value counts the
    (start, end) substring occurrences of S that are lexicographically
    smaller than T.  This reading is derived from the arithmetic below;
    the problem statement itself is not part of this file.

    Working in suffix-array order, with ``l`` the rank of T's suffix:
      * Rank ``l`` itself and every rank whose LCP with it is >= x
        contribute ``x - 1`` (the prefixes shorter than T).
      * Lower ranks with LCP < x contribute all their prefixes, i.e. the
        whole suffix length.
      * Higher ranks with LCP < x contribute exactly that LCP, summed via
        the precomputed table ``p``.
    """
    readline = sys.stdin.readline
    n, q = map(int, readline().split())
    n1 = n - 1
    # Strip the terminator (LF or CRLF) so that s has exactly n characters
    # and no stray '\n' / '\r' takes part in the suffix array.
    s = readline().rstrip("\r\n")
    a = string.suffix_array(s)          # a[rank] = start of that suffix
    b = string.lcp_array(s, a)          # b[k] = LCP(rank k, rank k + 1)
    a_inv = [0] * n
    for rank, start in enumerate(a):
        a_inv[start] = rank
    # cum_len[k] = total length of the suffixes at ranks 0..k.
    cum_len = [0] * n
    total = 0
    for k in range(n):
        total += n - a[k]
        cum_len[k] = total
    seg = SparseTable(b)
    # p[i] = sum over ranks k > i of LCP(rank i, rank k) = min(b[i:k]).
    # For k up to nxt (the first j > i with b[j] <= b[i]) that min is
    # b[i]; beyond nxt it equals the corresponding min starting at nxt.
    p = [0] * n1
    for i in range(n1 - 1, -1, -1):
        nxt = seg.minLeft(i + 1, b[i] + 1)
        p[i] = b[i] * (nxt - i)
        if nxt != n1:
            p[i] += p[nxt]
    out = []
    for _ in range(q):
        ll, rr = map(int, readline().split())
        ll -= 1
        l, x = a_inv[ll], rr - ll
        ans1, ans2 = 1, 0
        if l != 0:
            # Ranks z+1..l-1 share >= x chars with T; ranks 0..z do not.
            z = seg.maxRight(l - 1, x)
            if z != -1:
                ans2 += cum_len[z]
            ans1 += l - 1 - z
        if l != n1:
            # Ranks l+1..y share >= x chars with T; rank y+1 onward do not.
            y = seg.minLeft(l, x)
            if y != n1:
                ans2 += p[y]
            ans1 += y - l
        out.append(ans1 * (x - 1) + ans2)
    if out:
        sys.stdout.write("\n".join(map(str, out)) + "\n")


if __name__ == "__main__":
    main()
0