結果
| 問題 | No.515 典型LCP |
|---|---|
| コンテスト | |
| ユーザー | |
| 提出日時 | 2020-09-20 16:49:31 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | RE |
| 実行時間 | - |
| コード長 | 10,147 bytes |
| コンパイル時間 | 227 ms |
| コンパイル使用メモリ | 81,792 KB |
| 実行使用メモリ | 198,920 KB |
| 最終ジャッジ日時 | 2024-06-24 10:47:48 |
| 合計ジャッジ時間 | 7,344 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | RE * 15 |
ソースコード
import sys
def input():
    """Return the next line of standard input with trailing whitespace removed."""
    return sys.stdin.readline().rstrip()


# Effectively unlimited recursion depth for the recursive helpers in this file.
sys.setrecursionlimit(10**9)


def write(x):
    """Write the string ``x`` followed by a newline to standard output.

    Returns whatever ``sys.stdout.write`` returns.
    """
    return sys.stdout.write(x + "\n")
### Segment tree
class SegmentTree:
    """Segment tree over a fixed-size array with an associative operation.

    Nodes are stored in the flat list ``seg`` in heap order: node ``k`` has
    children ``2*k+1`` and ``2*k+2``, and array element ``i`` is the leaf
    ``seg[num - 1 + i]``, where ``num`` is the smallest power of two strictly
    greater than ``n``. Unused leaves hold the identity element.

    The operation and its identity default to the module-level globals
    ``op`` and ``ninf`` (read once, when the tree is constructed), so
    ``SegmentTree(n, a)`` keeps working as before; ``combine``/``identity``
    allow using the class without those globals.
    """

    def __init__(self, n, a=None, combine=None, identity=None):
        """Build a tree for ``n`` elements.

        n: number of array elements.
        a: optional initial values; must have length ``n``. Without it,
            every element starts as the identity.
        combine: associative binary operation (``None`` -> global ``op``).
        identity: identity element of ``combine`` (``None`` -> global ``ninf``).

        Raises ValueError if ``a`` is given and ``len(a) != n``.
        """
        self.op = op if combine is None else combine
        self.ide = ninf if identity is None else identity
        num = 1
        while num <= n:
            num *= 2
        self.num = num
        self.seg = [self.ide] * (2 * num - 1)
        if a is not None:
            if len(a) != n:
                raise ValueError("len(a) must equal n")
            # O(n) build: copy the leaves, then fill parents bottom-up.
            self.seg[num - 1:num - 1 + n] = a
            for k in range(num - 2, -1, -1):
                self.seg[k] = self.op(self.seg[2 * k + 1], self.seg[2 * k + 2])

    def update(self, i, x):
        """Set element ``i`` to ``x`` and recompute all of its ancestors."""
        k = i + self.num - 1
        self.seg[k] = x
        while k > 0:
            k = (k - 1) // 2
            self.seg[k] = self.op(self.seg[2 * k + 1], self.seg[2 * k + 2])

    def query(self, a, b):
        """Return the combination of the elements in the half-open range [a, b).

        Parts of the range outside [0, num) are ignored; an empty range
        yields the identity element.
        """
        a = max(a, 0)
        b = min(b, self.num)
        left = right = self.ide
        # Bottom-up walk on 1-based heap positions (0-based node k <-> k + 1),
        # keeping left/right accumulators so the combination order is kept.
        lo = a + self.num
        hi = b + self.num
        while lo < hi:
            if lo & 1:
                left = self.op(left, self.seg[lo - 1])
                lo += 1
            if hi & 1:
                hi -= 1
                right = self.op(self.seg[hi - 1], right)
            lo >>= 1
            hi >>= 1
        return self.op(left, right)

    def find_right(self, a, b, x=None, f=None):
        """Return the largest index i in [a, b) with f(A[i]) true, or -1.

        By default f(y) is ``y >= x``. A subtree is skipped as soon as f is
        false on its node value, which is only sound when that guarantees f
        is false for every element below it -- e.g. op=max together with a
        monotone predicate such as the default.
        """
        if f is None:
            def f(y):
                return y >= x
        leaf_start = self.num - 1
        stack = [(0, 0, self.num)]
        while stack:
            k, l, r = stack.pop()
            if r <= a or b <= l or not f(self.seg[k]):
                continue
            if k >= leaf_start:
                # Leaves are reached right-to-left, so the first hit is the max.
                return k - leaf_start
            mid = (l + r) // 2
            # Push the left half first so the right half is explored first.
            stack.append((2 * k + 1, l, mid))
            stack.append((2 * k + 2, mid, r))
        return -1

    def find_left(self, a, b, x=None, f=None):
        """Return the smallest index i in [a, b) with f(A[i]) true, or ``num``.

        By default f(y) is ``y >= x``. The same pruning caveat as
        ``find_right`` applies (sound for op=max with a monotone predicate).
        """
        if f is None:
            def f(y):
                return y >= x
        leaf_start = self.num - 1
        stack = [(0, 0, self.num)]
        while stack:
            k, l, r = stack.pop()
            if r <= a or b <= l or not f(self.seg[k]):
                continue
            if k >= leaf_start:
                # Leaves are reached left-to-right, so the first hit is the min.
                return k - leaf_start
            mid = (l + r) // 2
            # Push the right half first so the left half is explored first.
            stack.append((2 * k + 2, mid, r))
            stack.append((2 * k + 1, l, mid))
        return self.num

    def query_index(self, a, b, k=0, l=0, r=None):
        """Return ``(value, index)`` for the combination over [a, b).

        ``index`` is the position of an element equal to the result (the
        leftmost one on ties), or ``None`` together with the identity when
        the range is empty. Only meaningful for selection-type operations
        such as ``min`` or ``max``, whose result is always one of the inputs.
        k, l, r: recursion state -- the current node and its range [l, r).
        """
        if r is None:
            r = self.num
        if r <= a or b <= l:
            return (self.ide, None)
        if a <= l and r <= b:
            return (self.seg[k], self._index(k))
        mid = (l + r) // 2
        lv, li = self.query_index(a, b, 2 * k + 1, l, mid)
        rv, ri = self.query_index(a, b, 2 * k + 2, mid, r)
        # Combine explicitly instead of comparing tuples, which would compare
        # a None index with an int on value ties.
        if li is None:
            return (rv, ri)
        if ri is None:
            return (lv, li)
        return (lv, li) if self.op(lv, rv) == lv else (rv, ri)

    def _index(self, k):
        """Return the array index of a leaf below node ``k`` holding seg[k].

        Descends into the child that wins under ``op`` (the left one on
        ties). Leaves are the nodes with ``k >= num - 1``.
        """
        leaf_start = self.num - 1
        while k < leaf_start:
            left, right = 2 * k + 1, 2 * k + 2
            if self.op(self.seg[left], self.seg[right]) == self.seg[left]:
                k = left
            else:
                k = right
        return k - leaf_start
from copy import copy
def sa_naive(s):
    """Suffix array by sorting the suffix slices directly (O(n^2 log n)).

    Only suitable for very short inputs. Works for any sliceable sequence
    whose elements are mutually comparable (list of ints, str, ...).
    """
    return sorted(range(len(s)), key=lambda start: s[start:])
def sa_doubling(s):
    """Suffix array by prefix doubling (O(n log^2 n)).

    Suffixes are ranked by their first k elements for k = 1, 2, 4, ...
    until k >= n. A suffix that has no k-element continuation must sort
    before one that does; that is encoded with a shorter key tuple
    (``(a,) < (a, b)`` in Python), so no sentinel value is needed and any
    mutually comparable elements work -- including negative ints and str.

    s: sequence of mutually comparable elements (list of ints, str, ...).
    Returns the list of suffix start positions in lexicographic order.
    """
    n = len(s)
    sa = list(range(n))
    rnk = list(s)
    tmp = [0] * n
    k = 1

    def key(i):
        # Reads the current ``rnk`` and ``k`` of the enclosing loop.
        return (rnk[i], rnk[i + k]) if i + k < n else (rnk[i],)

    while k < n:
        sa.sort(key=key)
        tmp[sa[0]] = 0
        for prev, cur in zip(sa, sa[1:]):
            # Same rank as the predecessor unless its key is strictly smaller.
            tmp[cur] = tmp[prev] + (key(prev) < key(cur))
        rnk, tmp = tmp, rnk
        k *= 2
    return sa
def sa_is(s, upper):
    """Suffix array by SA-IS (induced sorting), structured like the ACL version.

    s: list of non-negative ints. For inputs of length >= 50 every value
        must be <= ``upper``, because the bucket arrays are indexed by value.
    upper: upper bound of the values in ``s``.
    Returns the list of suffix start positions in lexicographic order.

    Short inputs are delegated to sa_naive (n < 10) and sa_doubling (n < 50).
    """
    n = len(s)
    if n == 0:
        return []
    if n == 1:
        return [0]
    if n == 2:
        return [0, 1] if s[0] < s[1] else [1, 0]
    if n < 10:
        return sa_naive(s)
    if n < 50:
        return sa_doubling(s)

    # ls[i] is truthy iff suffix i is S-type (smaller than suffix i + 1);
    # the last suffix is L-type by convention.
    ls = [0] * n
    for i in range(n - 2, -1, -1):
        ls[i] = ls[i + 1] if s[i] == s[i + 1] else s[i] < s[i + 1]

    # Bucket boundaries per value c: L-type suffixes start at sum_l[c],
    # S-type suffixes start at sum_s[c]. The loop must cover c == upper too.
    sum_l = [0] * (upper + 1)
    sum_s = [0] * (upper + 1)
    for i in range(n):
        if ls[i]:
            sum_l[s[i] + 1] += 1
        else:
            sum_s[s[i]] += 1
    for i in range(upper + 1):
        sum_s[i] += sum_l[i]
        if i < upper:
            sum_l[i + 1] += sum_s[i]

    # LMS positions: S-type suffixes whose left neighbour is L-type.
    # lms_map[i] is the ordinal of LMS position i, or -1.
    lms_map = [-1] * (n + 1)
    lms = []
    for i in range(1, n):
        if not ls[i - 1] and ls[i]:
            lms_map[i] = len(lms)
            lms.append(i)
    m = len(lms)

    def induce(seeds):
        """Induce a full suffix order from LMS suffixes placed in ``seeds`` order."""
        sa = [-1] * n
        buf = sum_s.copy()
        for d in seeds:
            if d == n:
                continue
            sa[buf[s[d]]] = d
            buf[s[d]] += 1
        # Forward pass: place L-type suffixes at the fronts of their buckets.
        buf = sum_l.copy()
        sa[buf[s[n - 1]]] = n - 1
        buf[s[n - 1]] += 1
        for i in range(n):
            v = sa[i]
            if v >= 1 and not ls[v - 1]:
                sa[buf[s[v - 1]]] = v - 1
                buf[s[v - 1]] += 1
        # Backward pass: place S-type suffixes at the ends of their buckets.
        buf = sum_l.copy()
        for i in range(n - 1, -1, -1):
            v = sa[i]
            if v >= 1 and ls[v - 1]:
                buf[s[v - 1] + 1] -= 1
                sa[buf[s[v - 1] + 1]] = v - 1
        return sa

    sa = induce(lms)
    if m:
        # Name each LMS substring by its rank among the distinct LMS
        # substrings (in induced order), then sort that reduced string.
        sorted_lms = [v for v in sa if lms_map[v] != -1]
        rec_s = [0] * m
        rec_upper = 0
        rec_s[lms_map[sorted_lms[0]]] = 0
        for i in range(1, m):
            l = sorted_lms[i - 1]
            r = sorted_lms[i]
            end_l = lms[lms_map[l] + 1] if lms_map[l] + 1 < m else n
            end_r = lms[lms_map[r] + 1] if lms_map[r] + 1 < m else n
            same = True
            if end_l - l != end_r - r:
                same = False
            else:
                while l < end_l:
                    if s[l] != s[r]:
                        break
                    l += 1
                    r += 1
                if l == n or s[l] != s[r]:
                    same = False
            if not same:
                rec_upper += 1
            rec_s[lms_map[sorted_lms[i]]] = rec_upper
        rec_sa = sa_is(rec_s, rec_upper)
        sorted_lms = [lms[j] for j in rec_sa]
        sa = induce(sorted_lms)
    return sa
def suffix_array(s, upper=255):
    """Return the suffix array of ``s``.

    s: a str (converted with ``ord``) or a list of non-negative ints.
    upper: upper bound on the values of ``s``; 255 covers byte-range text.
        If the data contains a larger value, the bound is raised to the
        data's maximum. Otherwise sa_is would index past its bucket arrays
        and raise IndexError on inputs of length >= 50.
    """
    if isinstance(s, str):
        s = [ord(c) for c in s]
    if s:
        upper = max(upper, max(s))
    return sa_is(s, upper)
def lcp_array(s, sa):
    """Kasai's algorithm: longest common prefixes of adjacent suffixes.

    s: the sequence the suffix array was built from.
    sa: its suffix array.
    Returns a list of length len(s) - 1 whose entry r is the LCP of the
    suffixes starting at sa[r] and sa[r + 1] (empty for inputs of length <= 1).
    """
    n = len(s)
    rank = [0] * n
    for pos, start in enumerate(sa):
        rank[start] = pos
    lcp = [0] * (n - 1)
    h = 0
    for i, r in enumerate(rank):
        # Dropping one leading element loses at most one matched position.
        if h:
            h -= 1
        if r == 0:
            continue
        j = sa[r - 1]
        while i + h < n and j + h < n and s[i + h] == s[j + h]:
            h += 1
        lcp[r - 1] = h
    return lcp
def preproc(l):
    """Replace every item of the list ``l`` in place with an integer code.

    Single-character strings become ``ord(c)``. Every other distinct item
    (e.g. the integer separators placed between input strings) gets a new
    code, assigned in order of first appearance, that is larger than every
    character code in ``l`` and at least ``ord("z") + 1000``. Equal items
    receive equal codes; items must be hashable.

    Returns ``l`` itself (mutated) for convenience.
    """
    def is_char(item):
        return isinstance(item, str) and len(item) == 1

    max_char = max((ord(item) for item in l if is_char(item)), default=-1)
    next_code = max(ord("z") + 1000, max_char + 1)
    codes = {}
    for idx, item in enumerate(l):
        if item not in codes:
            if is_char(item):
                codes[item] = ord(item)
            else:
                codes[item] = next_code
                next_code += 1
        l[idx] = codes[item]
    return l
# ---- main: build SA + LCP over all strings, answer the generated queries ----
n = int(input())
s = []
# tmp maps the start offset of each string inside ``s`` to its 0-based index.
tmp = {}
for i in range(n):
    tmp[len(s)] = i
    s.extend(input())
    # A distinct integer separator after every string stops common prefixes
    # from running across string boundaries; preproc codes it above all chars.
    s.append(i)
preproc(s)
# preproc emits codes far above 255 (separators start at ord("z") + 1000),
# so the alphabet bound has to come from the data, not the 255 default.
sa = suffix_array(s, max(s))
# rsa[i] = rank (position in sa) of the suffix that starts string i.
rsa = [None] * n
for i, item in enumerate(sa):
    if item in tmp:
        rsa[tmp[item]] = i
lcp = lcp_array(s, sa)
# Globals read by SegmentTree: identity element and combining operation.
ninf = 10**9
op = min
sg = SegmentTree(len(lcp), lcp)
M, x, d = map(int, input().split())
ans = 0


def sub(i, j):
    """Return the LCP of strings i and j (0-based, i != j).

    The LCP of two suffixes is the minimum of lcp[lo:hi], where lo < hi are
    their ranks in the suffix array.
    """
    ii, jj = rsa[i], rsa[j]
    if ii > jj:
        ii, jj = jj, ii
    return sg.query(ii, jj)


for _ in range(M):
    # Query generator from the problem statement (1-based, yields i < j).
    i = (x // (n - 1)) + 1
    j = (x % (n - 1)) + 1
    if i > j:
        i, j = j, i
    else:
        j += 1
    x = (x + d) % (n * (n - 1))
    ans += sub(i - 1, j - 1)
print(ans)
# for k in 1 .. M
# i[k] = (x / (n - 1)) + 1
# j[k] = (x % (n - 1)) + 1
# if (i[k] > j[k])
# swap(i[k], j[k])
# else
# j[k] = j[k] + 1
# end
# x = (x + d) % (n * (n - 1))