結果
問題 | No.515 典型LCP |
ユーザー | shotoyoo |
提出日時 | 2020-09-20 16:58:21 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 10,066 bytes |
コンパイル時間 | 202 ms |
コンパイル使用メモリ | 82,304 KB |
実行使用メモリ | 275,776 KB |
最終ジャッジ日時 | 2024-06-24 10:48:05 |
合計ジャッジ時間 | 4,892 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge1 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | TLE | - |
testcase_01 | -- | - |
testcase_02 | -- | - |
testcase_03 | -- | - |
testcase_04 | -- | - |
testcase_05 | -- | - |
testcase_06 | -- | - |
testcase_07 | -- | - |
testcase_08 | -- | - |
testcase_09 | -- | - |
testcase_10 | -- | - |
testcase_11 | -- | - |
testcase_12 | -- | - |
testcase_13 | -- | - |
testcase_14 | -- | - |
testcase_15 | -- | - |
testcase_16 | -- | - |
ソースコード
import sys input = lambda : sys.stdin.readline().rstrip() sys.setrecursionlimit(max(1000, 10**9)) write = lambda x: sys.stdout.write(x+"\n") ### セグメント木 class SegmentTree: def __init__(self, n, a=None): """初期化 num : n以上の最小の2のべき乗 """ num = 1 while num<=n: num *= 2 self.num = num self.seg = [ninf] * (2*self.num-1) if a is not None: # O(n)で初期化 assert len(a)==n for i in range(n): self.seg[num-1+i] = a[i] for k in range(num-2, -1, -1): self.seg[k] = op(self.seg[2*k+1], self.seg[2*k+2]) def update(self,i,x): """update(i,x):Aiをxに更新する """ k = i+(self.num-1) self.seg[k] = x k = (k-1)//2 while k >= 0: self.seg[k] = op(self.seg[2*k+1], self.seg[2*k+2]) k = (k-1)//2 def query(self,a,b): k = 0 l = 0 r = self.num q = [(k,l,r)] ans = ninf # 重なる区間を深さ優先探索 while q: k,l,r = q.pop() if r<=a or b<=l: pass elif a<=l and r<=b: ans = op(ans, self.seg[k]) else: q.append((2*k+1,l,(l+r)//2)) q.append((2*k+2,(l+r)//2,r)) return ans def find_right(self,a,b,x=None,f=None): """[a,b)で値がx以上のインデックスの最大 存在しない場合-1を返す """ if f is None: f = lambda y: y>=x k = 0 l = 0 r = self.num q = [(k,l,r,True)] # 行きがけかどうか ans = -1 while q: k,l,r,flg = q.pop() if flg: if not f(self.seg[k]) or r<=a or b<=l: # 条件を満たせない or 区間が重複しない pass elif k>=self.num-1: # 自身が葉 ans = max(ans, k - (self.num-1)) return ans else: # 左への探索を予約 q.append((2*k+1,l,(l+r)//2,False)) # 右への探索 q.append((2*k+2,(l+r)//2,r,True)) else: if ans>=0: return ans q.append((k,l,r,True)) return ans def find_left(self,a,b,x=None, f=None): """[a,b)で値がx以上のインデックス(0,1,...,self.num-1)の最小 条件を満たすものが存在しないとき、self.numを返す """ if f is None: f = lambda y: y>=x k = 0 l = 0 r = self.num q = [(k,l,r,True)] # 行きがけかどうか ans = self.num while q: k,l,r,flg = q.pop() if flg: if not f(self.seg[k]) or r<=a or b<=l: # x以上を満たせない or 区間が重複しない continue elif k>=self.num-1: # 自身が葉 ans = min(ans, k - (self.num-1)) return ans else: # 右への探索を予約 q.append((2*k+2,(l+r)//2,r,False)) # 左への探索 q.append((2*k+1,l,(l+r)//2,True)) else: if ans<self.num: return ans q.append((k,l,r,True)) return ans def query_index(self,a,b,k=0,l=0,r=None): """query(a,b,0,0,num):[a,b)の最大値 最大値を与えるインデックスも返す """ if r is None: r = self.num if r <= a or b <= l: return (ninf, None) elif a <= l and r <= b: return (self.seg[k], self._index(k)) else: return op(self.query_index(a,b,2*k+1,l,(l+r)//2),self.query_index(a,b,2*k+2,(l+r)//2,r)) def _index(self, k): if k>=self.num: return k - (self.num-1) else: if self.seg[2*k+1]>=self.seg[2*k+2]: return self._index(2*k+1) else: return self._index(2*k+2) from copy import copy def sa_naive(s): n = len(s) sa = list(range(n)) sa.sort(key=lambda x: s[x:]) return sa def sa_doubling(s): n = len(s) sa = list(range(n)) rnk = copy(s) tmp = [0] * n k = 1 while k < n: sa.sort(key=lambda x: (rnk[x], rnk[x + k]) if x + k < n else (rnk[x], -1)) tmp[sa[0]] = 0 for i in range(1, n): tmp[sa[i]] = tmp[sa[i - 1]] if sa[i - 1] + k < n: x = (rnk[sa[i - 1]], rnk[sa[i - 1] + k]) else: x = (rnk[sa[i - 1]], -1) if sa[i] + k < n: y = (rnk[sa[i]], rnk[sa[i] + k]) else: y = (rnk[sa[i]], -1) if x < y: tmp[sa[i]] += 1 k *= 2 tmp, rnk = rnk, tmp return sa def sa_is(s, upper): n = len(s) if n == 0: return [] if n == 1: return [0] if n == 2: if s[0] < s[1]: return [0, 1] else: return [1, 0] if n < 10: return sa_naive(s) if n < 50: return sa_doubling(s) ls = [0] * n for i in range(n-2, -1, -1): ls[i] = ls[i + 1] if s[i] == s[i + 1] else s[i] < s[i + 1] sum_l = [0] * (upper + 1) sum_s = [0] * (upper + 1) for i in range(n): if ls[i]: sum_l[s[i] + 1] += 1 else: sum_s[s[i]] += 1 for i in range(upper): sum_s[i] += sum_l[i] if i < upper: sum_l[i + 1] += sum_s[i] lms_map = [-1] * (n + 1) m = 0 for i in range(1, n): if not ls[i - 1] and ls[i]: lms_map[i] = m m += 1 lms = [] for i in range(1, n): if not ls[i - 1] and ls[i]: lms.append(i) sa = [-1] * n buf = sum_s.copy() for d in lms: if d == n: continue sa[buf[s[d]]] = d buf[s[d]] += 1 buf = sum_l.copy() sa[buf[s[n - 1]]] = n - 1 buf[s[n - 1]] += 1 for i in range(n): v = sa[i] if v >= 1 and not ls[v - 1]: sa[buf[s[v - 1]]] = v - 1 buf[s[v - 1]] += 1 buf = sum_l.copy() for i in range(n-1, -1, -1): v = sa[i] if v >= 1 and ls[v - 1]: buf[s[v - 1] + 1] -= 1 sa[buf[s[v - 1] + 1]] = v - 1 if m: sorted_lms = [] for v in sa: if lms_map[v] != -1: sorted_lms.append(v) rec_s = [0] * m rec_upper = 0 rec_s[lms_map[sorted_lms[0]]] = 0 for i in range(1, m): l = sorted_lms[i - 1] r = sorted_lms[i] end_l = lms[lms_map[l] + 1] if lms_map[l] + 1 < m else n end_r = lms[lms_map[r] + 1] if lms_map[r] + 1 < m else n same = True if end_l - l != end_r - r: same = False else: while l < end_l: if s[l] != s[r]: break l += 1 r += 1 if l == n or s[l] != s[r]: same = False if not same: rec_upper += 1 rec_s[lms_map[sorted_lms[i]]] = rec_upper rec_sa = sa_is(rec_s, rec_upper) for i in range(m): sorted_lms[i] = lms[rec_sa[i]] sa = [-1] * n buf = sum_s.copy() for d in sorted_lms: if d == n: continue sa[buf[s[d]]] = d buf[s[d]] += 1 buf = sum_l.copy() sa[buf[s[n - 1]]] = n - 1 buf[s[n - 1]] += 1 for i in range(n): v = sa[i] if v >= 1 and not ls[v - 1]: sa[buf[s[v - 1]]] = v - 1 buf[s[v - 1]] += 1 buf = sum_l.copy() for i in range(n-1, -1, -1): v = sa[i] if v >= 1 and ls[v - 1]: buf[s[v - 1] + 1] -= 1 sa[buf[s[v - 1] + 1]] = v - 1 return sa def suffix_array(s, upper=None): """sは文字列または0以上の整数のリスト """ if isinstance(s, str): s = [ord(c) for c in s] if upper is None: upper = max(s) return sa_is(s, upper) def lcp_array(s, sa): n = len(s) rnk = [0] * n for i in range(n): rnk[sa[i]] = i lcp = [0] * (n - 1) h = 0 for i in range(n): if h > 0: h -= 1 if rnk[i] == 0: continue j = sa[rnk[i] - 1] while j + h < n and i + h < n: if s[j + h] != s[i + h]: break h += 1 lcp[rnk[i] - 1] = h return lcp def preproc(l): """リスト中の文字列→ ord()で置き換え それ以外→ ord()の最大値より大きい値で置き換え """ d = {} i = ord("z") + 1000 for ind,item in enumerate(s): if item not in d: if isinstance(item, str) and len(item)==1: d[item] = ord(item) else: d[item] = i i += 1 s[ind] = d[item] n = int(input()) s = [] # tmp = [None]*n tmp = {} for i in range(n): # tmp[i] = len(s) tmp[len(s)] = i s.extend(list(input())) s.append(i) preproc(s) sa = suffix_array(s) rsa = [None]*n for i,item in enumerate(sa): if item in tmp: rsa[tmp[item]] = i lcp = lcp_array(s, sa) ninf = 10**9 op = min sg = SegmentTree(len(lcp), lcp) M,x,d = map(int, input().split()) ans = 0 def sub(i,j): ii,jj = rsa[i], rsa[j] if ii>jj: ii,jj = jj,ii return sg.query(ii,jj) for k in range(1, M+1): i = (x // (n-1)) + 1 j = (x % (n-1)) + 1 if i>j: i,j = j,i else: j += 1 x = (x + d) % (n * (n-1)) # print(i,j) ans += sub(i-1,j-1) print(ans)