結果

問題 No.515 典型LCP
ユーザー shotoyooshotoyoo
提出日時 2020-09-20 16:58:21
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 10,066 bytes
コンパイル時間 375 ms
コンパイル使用メモリ 86,996 KB
実行使用メモリ 265,496 KB
最終ジャッジ日時 2023-09-06 16:22:18
合計ジャッジ時間 5,051 ms
ジャッジサーバーID
(参考情報)
judge15 / judge11
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 TLE -
testcase_01 -- -
testcase_02 -- -
testcase_03 -- -
testcase_04 -- -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = lambda : sys.stdin.readline().rstrip()
sys.setrecursionlimit(max(1000, 10**9))
write = lambda x: sys.stdout.write(x+"\n")


### セグメント木
class SegmentTree:
    def __init__(self, n, a=None):
        """初期化
        num : n以上の最小の2のべき乗
        """
        num = 1
        while num<=n:
            num *= 2
        self.num = num
        self.seg = [ninf] * (2*self.num-1)
        if a is not None:
            # O(n)で初期化
            assert len(a)==n
            for i in range(n):
                self.seg[num-1+i] = a[i]
            for k in range(num-2, -1, -1):
                self.seg[k] = op(self.seg[2*k+1], self.seg[2*k+2])
    def update(self,i,x):
        """update(i,x):Aiをxに更新する
        """
        k = i+(self.num-1)
        self.seg[k] = x
        k = (k-1)//2
        while k >= 0:
            self.seg[k] = op(self.seg[2*k+1], self.seg[2*k+2])
            k = (k-1)//2
    def query(self,a,b):
        k = 0
        l = 0
        r = self.num
        q = [(k,l,r)]
        ans = ninf
        # 重なる区間を深さ優先探索
        while q:
            k,l,r = q.pop()
            if r<=a or b<=l:
                pass
            elif a<=l and r<=b:
                ans = op(ans, self.seg[k])
            else:
                q.append((2*k+1,l,(l+r)//2))
                q.append((2*k+2,(l+r)//2,r)) 
        return ans
    def find_right(self,a,b,x=None,f=None):
        """[a,b)で値がx以上のインデックスの最大
        存在しない場合-1を返す
        """
        if f is None:
            f = lambda y: y>=x
        k = 0
        l = 0
        r = self.num
        q = [(k,l,r,True)] # 行きがけかどうか
        ans = -1
        while q:
            k,l,r,flg = q.pop()
            if flg:
                if not f(self.seg[k]) or r<=a or b<=l: # 条件を満たせない or 区間が重複しない
                    pass
                elif k>=self.num-1: # 自身が葉
                    ans = max(ans, k - (self.num-1))
                    return ans
                else:
                    # 左への探索を予約
                    q.append((2*k+1,l,(l+r)//2,False))
                    # 右への探索
                    q.append((2*k+2,(l+r)//2,r,True))
            else:
                if ans>=0:
                    return ans
                q.append((k,l,r,True))
        return ans
    def find_left(self,a,b,x=None, f=None):
        """[a,b)で値がx以上のインデックス(0,1,...,self.num-1)の最小
        条件を満たすものが存在しないとき、self.numを返す
        """
        if f is None:
            f = lambda y: y>=x
        k = 0
        l = 0
        r = self.num
        q = [(k,l,r,True)] # 行きがけかどうか
        ans = self.num
        while q:
            k,l,r,flg = q.pop()
            if flg:
                if not f(self.seg[k]) or r<=a or b<=l: # x以上を満たせない or 区間が重複しない
                    continue
                elif k>=self.num-1: # 自身が葉
                    ans = min(ans, k - (self.num-1))
                    return ans
                else:
                    # 右への探索を予約
                    q.append((2*k+2,(l+r)//2,r,False))
                    # 左への探索
                    q.append((2*k+1,l,(l+r)//2,True))
            else:
                if ans<self.num:
                    return ans
                q.append((k,l,r,True))
        return ans
    def query_index(self,a,b,k=0,l=0,r=None):
        """query(a,b,0,0,num):[a,b)の最大値
        最大値を与えるインデックスも返す
        """
        if r is None:
            r = self.num
        if r <= a or b <= l:
            return (ninf, None)
        elif a <= l and r <= b:
            return (self.seg[k], self._index(k))
        else:
            return op(self.query_index(a,b,2*k+1,l,(l+r)//2),self.query_index(a,b,2*k+2,(l+r)//2,r))
    def _index(self, k):
        if k>=self.num:
            return k - (self.num-1)
        else:
            if self.seg[2*k+1]>=self.seg[2*k+2]:
                return self._index(2*k+1)
            else:
                return self._index(2*k+2)

from copy import copy
def sa_naive(s):
    n = len(s)
    sa = list(range(n))
    sa.sort(key=lambda x: s[x:])
    return sa

def sa_doubling(s):
    n = len(s)
    sa = list(range(n))
    rnk = copy(s)
    tmp = [0] * n
    k = 1
    while k < n:
        sa.sort(key=lambda x: (rnk[x], rnk[x + k])
                if x + k < n else (rnk[x], -1))
        tmp[sa[0]] = 0
        for i in range(1, n):
            tmp[sa[i]] = tmp[sa[i - 1]]
            if sa[i - 1] + k < n:
                x = (rnk[sa[i - 1]], rnk[sa[i - 1] + k])
            else:
                x = (rnk[sa[i - 1]], -1)
            if sa[i] + k < n:
                y = (rnk[sa[i]], rnk[sa[i] + k])
            else:
                y = (rnk[sa[i]], -1)
            if x < y:
                tmp[sa[i]] += 1
        k *= 2
        tmp, rnk = rnk, tmp
    return sa

def sa_is(s, upper):
    n = len(s)
    if n == 0:
        return []
    if n == 1:
        return [0]
    if n == 2:
        if s[0] < s[1]:
            return [0, 1]
        else:
            return [1, 0]
    if n < 10:
        return sa_naive(s)
    if n < 50:
        return sa_doubling(s)
    ls = [0] * n
    for i in range(n-2, -1, -1):
        ls[i] = ls[i + 1] if s[i] == s[i + 1] else s[i] < s[i + 1]
    sum_l = [0] * (upper + 1)
    sum_s = [0] * (upper + 1)
    for i in range(n):
        if ls[i]:
            sum_l[s[i] + 1] += 1
        else:
            sum_s[s[i]] += 1
    for i in range(upper):
        sum_s[i] += sum_l[i]
        if i < upper:
            sum_l[i + 1] += sum_s[i]
    lms_map = [-1] * (n + 1)
    m = 0
    for i in range(1, n):
        if not ls[i - 1] and ls[i]:
            lms_map[i] = m
            m += 1
    lms = []
    for i in range(1, n):
        if not ls[i - 1] and ls[i]:
            lms.append(i)
    sa = [-1] * n
    buf = sum_s.copy()
    for d in lms:
        if d == n:
            continue
        sa[buf[s[d]]] = d
        buf[s[d]] += 1
    buf = sum_l.copy()
    sa[buf[s[n - 1]]] = n - 1
    buf[s[n - 1]] += 1
    for i in range(n):
        v = sa[i]
        if v >= 1 and not ls[v - 1]:
            sa[buf[s[v - 1]]] = v - 1
            buf[s[v - 1]] += 1
    buf = sum_l.copy()
    for i in range(n-1, -1, -1):
        v = sa[i]
        if v >= 1 and ls[v - 1]:
            buf[s[v - 1] + 1] -= 1
            sa[buf[s[v - 1] + 1]] = v - 1
    if m:
        sorted_lms = []
        for v in sa:
            if lms_map[v] != -1:
                sorted_lms.append(v)
        rec_s = [0] * m
        rec_upper = 0
        rec_s[lms_map[sorted_lms[0]]] = 0
        for i in range(1, m):
            l = sorted_lms[i - 1]
            r = sorted_lms[i]
            end_l = lms[lms_map[l] + 1] if lms_map[l] + 1 < m else n
            end_r = lms[lms_map[r] + 1] if lms_map[r] + 1 < m else n
            same = True
            if end_l - l != end_r - r:
                same = False
            else:
                while l < end_l:
                    if s[l] != s[r]:
                        break
                    l += 1
                    r += 1
                if l == n or s[l] != s[r]:
                    same = False
            if not same:
                rec_upper += 1
            rec_s[lms_map[sorted_lms[i]]] = rec_upper
        rec_sa = sa_is(rec_s, rec_upper)
        for i in range(m):
            sorted_lms[i] = lms[rec_sa[i]]
        sa = [-1] * n
        buf = sum_s.copy()
        for d in sorted_lms:
            if d == n:
                continue
            sa[buf[s[d]]] = d
            buf[s[d]] += 1
        buf = sum_l.copy()
        sa[buf[s[n - 1]]] = n - 1
        buf[s[n - 1]] += 1
        for i in range(n):
            v = sa[i]
            if v >= 1 and not ls[v - 1]:
                sa[buf[s[v - 1]]] = v - 1
                buf[s[v - 1]] += 1
        buf = sum_l.copy()
        for i in range(n-1, -1, -1):
            v = sa[i]
            if v >= 1 and ls[v - 1]:
                buf[s[v - 1] + 1] -= 1
                sa[buf[s[v - 1] + 1]] = v - 1
    return sa

def suffix_array(s, upper=None):
    """sは文字列または0以上の整数のリスト
    """
    if isinstance(s, str):
        s = [ord(c) for c in s]
    if upper is None:
        upper = max(s)
    return sa_is(s, upper)

def lcp_array(s, sa):
    n = len(s)
    rnk = [0] * n
    for i in range(n):
        rnk[sa[i]] = i
    lcp = [0] * (n - 1)
    h = 0
    for i in range(n):
        if h > 0:
            h -= 1
        if rnk[i] == 0:
            continue
        j = sa[rnk[i] - 1]
        while j + h < n and i + h < n:
            if s[j + h] != s[i + h]:
                break
            h += 1
        lcp[rnk[i] - 1] = h
    return lcp

def preproc(l):
    """リスト中の文字列→ ord()で置き換え
    それ以外→ ord()の最大値より大きい値で置き換え
    """
    d = {}
    i = ord("z") + 1000
    for ind,item in enumerate(s):
        if item not in d:
            if isinstance(item, str) and len(item)==1:
                d[item] = ord(item)
            else:
                d[item] = i
                i += 1
        s[ind] = d[item]


n = int(input())
s = []
# tmp = [None]*n
tmp = {}
for i in range(n):
#     tmp[i] = len(s)
    tmp[len(s)] = i
    s.extend(list(input()))
    s.append(i)
preproc(s)
sa = suffix_array(s)
rsa = [None]*n
for i,item in enumerate(sa):
    if item in tmp:
        rsa[tmp[item]] = i
lcp = lcp_array(s, sa)
ninf = 10**9
op = min
sg = SegmentTree(len(lcp), lcp)

M,x,d = map(int, input().split())
ans = 0
def sub(i,j):
    ii,jj = rsa[i], rsa[j]
    if ii>jj:
        ii,jj = jj,ii
    return sg.query(ii,jj)
for k in range(1, M+1):
    i = (x // (n-1)) + 1
    j = (x % (n-1)) + 1
    if i>j:
        i,j = j,i
    else:
        j += 1
    x = (x + d) % (n * (n-1))
#     print(i,j)
    ans += sub(i-1,j-1)
print(ans)
0