結果
問題 | No.515 典型LCP |
ユーザー | shotoyoo |
提出日時 | 2020-09-20 17:29:18 |
言語 | PyPy3 (7.3.15) |
結果 | TLE |
実行時間 | - |
コード長 | 6,526 bytes |
コンパイル時間 | 226 ms |
コンパイル使用メモリ | 82,396 KB |
実行使用メモリ | 450,504 KB |
最終ジャッジ日時 | 2024-06-24 10:50:44 |
合計ジャッジ時間 | 4,947 ms |
ジャッジサーバーID (参考情報) | judge1 / judge3 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | TLE | - |
testcase_01 | TLE | - |
testcase_02 | RE | - |
testcase_03 | RE | - |
testcase_04 | RE | - |
testcase_05 | TLE | - |
testcase_06 | TLE | - |
testcase_07 | TLE | - |
testcase_08 | TLE | - |
testcase_09 | RE | - |
testcase_10 | RE | - |
testcase_11 | RE | - |
testcase_12 | RE | - |
testcase_13 | RE | - |
testcase_14 | RE | - |
testcase_15 | TLE | - |
testcase_16 | TLE | - |
ソースコード
import sys
from copy import copy


def input():
    """Read one line from standard input with trailing whitespace removed."""
    return sys.stdin.readline().rstrip()


# Kept from the original script; same value as max(1000, 10**9).
sys.setrecursionlimit(10**9)


def write(x):
    """Write the string ``x`` followed by a newline to standard output."""
    return sys.stdout.write(x + "\n")


def sa_naive(s):
    """Return the suffix array of ``s`` by sorting the suffixes directly.

    Each suffix is materialized as a slice, so this is only meant for very
    short inputs (``sa_is`` uses it for lengths below 10).
    """
    return sorted(range(len(s)), key=lambda start: s[start:])


def sa_doubling(s):
    """Return the suffix array of the integer list ``s`` by prefix doubling.

    In each round the suffixes are sorted by the pair
    ``(rank of first k items, rank of the following k items)``; a suffix
    shorter than ``k`` gets ``-1`` as its second component.
    """
    n = len(s)
    sa = list(range(n))
    rnk = copy(s)
    tmp = [0] * n
    k = 1
    while k < n:

        def pair(pos):
            # Reads the current ``rnk`` and ``k`` every time it is called.
            return (rnk[pos], rnk[pos + k] if pos + k < n else -1)

        sa.sort(key=pair)
        tmp[sa[0]] = 0
        for i in range(1, n):
            # The rank grows only when the pair strictly increases.
            tmp[sa[i]] = tmp[sa[i - 1]] + (pair(sa[i - 1]) < pair(sa[i]))
        k *= 2
        tmp, rnk = rnk, tmp
    return sa


def _induced_sort(s, ls, sum_l, sum_s, lms):
    """Run one induced-sorting pass of SA-IS and return the resulting array.

    ``lms`` lists the LMS positions in the order they are seeded into their
    buckets.  ``sum_s[c]`` is the first slot of the S-type part of bucket
    ``c`` and ``sum_l[c]`` is the first slot of bucket ``c`` itself.
    """
    n = len(s)
    sa = [-1] * n

    # Seed the LMS suffixes into the S-type part of their buckets.
    buf = sum_s.copy()
    for d in lms:
        if d == n:
            continue
        sa[buf[s[d]]] = d
        buf[s[d]] += 1

    # Left-to-right scan: place L-type suffixes at the bucket fronts.
    buf = sum_l.copy()
    sa[buf[s[n - 1]]] = n - 1
    buf[s[n - 1]] += 1
    for i in range(n):
        v = sa[i]
        if v >= 1 and not ls[v - 1]:
            sa[buf[s[v - 1]]] = v - 1
            buf[s[v - 1]] += 1

    # Right-to-left scan: place S-type suffixes at the bucket backs.
    buf = sum_l.copy()
    for i in range(n - 1, -1, -1):
        v = sa[i]
        if v >= 1 and ls[v - 1]:
            buf[s[v - 1] + 1] -= 1
            sa[buf[s[v - 1] + 1]] = v - 1
    return sa


def sa_is(s, upper):
    """Return the suffix array of the integer list ``s`` using SA-IS.

    Args:
        s: list of integers, each in ``0..upper`` inclusive.
        upper: an upper bound on the values in ``s``; the bucket tables have
            ``upper + 1`` entries.

    Returns:
        The starting positions of all suffixes of ``s`` in sorted order.

    Inputs shorter than 10 items are delegated to ``sa_naive`` and inputs
    shorter than 50 items to ``sa_doubling``.
    """
    n = len(s)
    if n == 0:
        return []
    if n == 1:
        return [0]
    if n == 2:
        return [0, 1] if s[0] < s[1] else [1, 0]
    if n < 10:
        return sa_naive(s)
    if n < 50:
        return sa_doubling(s)

    # ls[i] is truthy when suffix i is S-type (smaller than suffix i + 1).
    ls = [False] * n
    for i in range(n - 2, -1, -1):
        ls[i] = ls[i + 1] if s[i] == s[i + 1] else s[i] < s[i + 1]

    # Bucket boundaries: L-type suffixes of a character come first,
    # S-type suffixes of the same character follow.
    sum_l = [0] * (upper + 1)
    sum_s = [0] * (upper + 1)
    for i in range(n):
        if ls[i]:
            sum_l[s[i] + 1] += 1
        else:
            sum_s[s[i]] += 1
    for i in range(upper + 1):
        sum_s[i] += sum_l[i]
        if i < upper:
            sum_l[i + 1] += sum_s[i]

    # LMS positions: an S-type suffix directly preceded by an L-type one.
    lms_map = [-1] * (n + 1)
    lms = []
    for i in range(1, n):
        if not ls[i - 1] and ls[i]:
            lms_map[i] = len(lms)
            lms.append(i)
    m = len(lms)

    sa = _induced_sort(s, ls, sum_l, sum_s, lms)

    if m:
        # lms_map[-1] is the padding slot lms_map[n] == -1, so a stray -1
        # entry in ``sa`` would simply be skipped here.
        sorted_lms = [v for v in sa if lms_map[v] != -1]

        # Name each LMS substring; equal substrings share a name.
        rec_s = [0] * m
        rec_upper = 0
        rec_s[lms_map[sorted_lms[0]]] = 0
        for i in range(1, m):
            left = sorted_lms[i - 1]
            right = sorted_lms[i]
            end_l = lms[lms_map[left] + 1] if lms_map[left] + 1 < m else n
            end_r = lms[lms_map[right] + 1] if lms_map[right] + 1 < m else n
            same = True
            if end_l - left != end_r - right:
                same = False
            else:
                while left < end_l:
                    if s[left] != s[right]:
                        break
                    left += 1
                    right += 1
                if left == n or s[left] != s[right]:
                    same = False
            if not same:
                rec_upper += 1
            rec_s[lms_map[sorted_lms[i]]] = rec_upper

        # Sort the reduced string recursively, then induce the final order.
        rec_sa = sa_is(rec_s, rec_upper)
        sorted_lms = [lms[k] for k in rec_sa]
        sa = _induced_sort(s, ls, sum_l, sum_s, sorted_lms)
    return sa


def suffix_array(s, upper=None):
    """Return the suffix array of ``s``.

    Args:
        s: a string, or a list of non-negative integers.
        upper: optional upper bound on the values in ``s``; defaults to
            ``max(s) + 1``.

    Returns:
        The starting positions of all suffixes of ``s`` in sorted order.
        An empty input yields an empty list.
    """
    if isinstance(s, str):
        s = [ord(c) for c in s]
    if not s:
        # max() of an empty sequence would raise ValueError.
        return []
    if upper is None:
        upper = max(s) + 1
    return sa_is(s, upper)


def lcp_array(s, sa):
    """Return the LCP array that belongs to the suffix array ``sa`` of ``s``.

    Entry ``k`` of the result is the length of the longest common prefix of
    the suffixes starting at ``sa[k]`` and ``sa[k + 1]``; the result has
    ``len(s) - 1`` entries.
    """
    n = len(s)
    rnk = [0] * n
    for i in range(n):
        rnk[sa[i]] = i
    lcp = [0] * (n - 1)
    h = 0
    for i in range(n):
        # Moving from suffix i - 1 to suffix i loses at most one character.
        if h > 0:
            h -= 1
        if rnk[i] == 0:
            continue
        j = sa[rnk[i] - 1]
        while j + h < n and i + h < n:
            if s[j + h] != s[i + h]:
                break
            h += 1
        lcp[rnk[i] - 1] = h
    return lcp


def preproc(l):
    """Replace every item of the list ``l`` by an integer code, in place.

    A one-character string becomes ``ord(item)``.  Any other item receives a
    fresh code starting at ``ord("z") + 1000`` and increasing by one for each
    new distinct item; equal items share a code.  Returns ``None``.
    """
    codes = {}
    next_code = ord("z") + 1000
    for ind, item in enumerate(l):
        if item not in codes:
            if isinstance(item, str) and len(item) == 1:
                codes[item] = ord(item)
            else:
                codes[item] = next_code
                next_code += 1
        l[ind] = codes[item]


class SparseTable:
    """Sparse table answering range queries with ``op`` (minimum by default)."""

    op = min

    def __init__(self, a):
        """Build the table over the list ``a``; level 0 is ``a`` itself."""
        n = len(a)
        op = self.op
        levels = []
        width = 1
        row = a
        while True:
            levels.append(row)
            width *= 2
            if width > n:
                break
            half = width // 2
            # The last ``half`` entries are copied over only to keep every
            # level at length n; query() never reads them.
            row = [op(row[i], row[i + half]) for i in range(n - half)] + row[-half:]
        # t[i][j] is op over a[j : j + 2**i].
        self.t = levels

    def query(self, l, r):
        """Return op over ``a[l:r]``; requires ``l < r``."""
        b = (r - l).bit_length() - 1
        return self.op(self.t[b][l], self.t[b][r - (1 << b)])


def solve(strings, m, x, d):
    """Return the sum of LCP lengths over ``m`` generated pairs of strings.

    All strings are concatenated, each followed by its own index as a unique
    separator, and a suffix array with an LCP sparse table is built over the
    result.  For each of the ``m`` queries a 1-based pair ``i < j`` is derived
    from ``x``, then ``x`` advances by ``d`` modulo ``n * (n - 1)``.

    Needs at least two strings whenever ``m > 0``, because the pair
    generation divides by ``n - 1``.
    """
    n = len(strings)
    text = []
    start_to_index = {}
    for idx, word in enumerate(strings):
        start_to_index[len(text)] = idx
        text.extend(word)
        text.append(idx)
    preproc(text)

    sa = suffix_array(text)
    # rank_of[k] is the position in ``sa`` of the suffix where string k starts.
    rank_of = [None] * n
    for pos, start in enumerate(sa):
        if start in start_to_index:
            rank_of[start_to_index[start]] = pos
    table = SparseTable(lcp_array(text, sa))

    total = 0
    for _ in range(m):
        i = x // (n - 1) + 1
        j = x % (n - 1) + 1
        if i > j:
            i, j = j, i
        else:
            j += 1
        x = (x + d) % (n * (n - 1))
        lo, hi = rank_of[i - 1], rank_of[j - 1]
        if lo > hi:
            lo, hi = hi, lo
        total += table.query(lo, hi)
    return total


def main():
    """Read the input from stdin and print the answer."""
    n = int(input())
    strings = [input() for _ in range(n)]
    m, x, d = map(int, input().split())
    print(solve(strings, m, x, d))


if __name__ == "__main__":
    main()