結果
問題 | No.2361 Many String Compare Queries |
ユーザー |
![]() |
提出日時 | 2025-03-26 15:55:13 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 6,790 bytes |
コンパイル時間 | 176 ms |
コンパイル使用メモリ | 82,444 KB |
実行使用メモリ | 563,744 KB |
最終ジャッジ日時 | 2025-03-26 15:56:08 |
合計ジャッジ時間 | 4,944 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 4 WA * 2 MLE * 1 -- * 7 |
ソースコード
def main():import sysinput = sys.stdin.readdata = input().split()idx = 0N = int(data[idx])Q = int(data[idx+1])idx +=2S = data[idx]idx +=1queries = []for _ in range(Q):L = int(data[idx])-1 # 0-basedR = int(data[idx+1])-1queries.append( (L, R) )idx +=2# Suffix Automaton implementationclass State:def __init__(self):self.next = {}self.link = -1self.len = 0self.cnt = 0 # number of occurrencesself.sum_i = 0 # sum of starting indicessa = [State()]last = 0for i, c in enumerate(S):p = lastcurr = len(sa)sa.append(State())sa[curr].len = sa[p].len +1sa[curr].cnt = 1sa[curr].sum_i = iwhile p >=0 and c not in sa[p].next:sa[p].next[c] = currp = sa[p].linkif p == -1:sa[curr].link = 0else:q = sa[p].next[c]if sa[p].len +1 == sa[q].len:sa[curr].link = qelse:clone = len(sa)sa.append(State())sa[clone].len = sa[p].len +1sa[clone].next = sa[q].next.copy()sa[clone].link = sa[q].linkwhile p >=0 and sa[p].next[c] == q:sa[p].next[c] = clonep = sa[p].linksa[q].link = clonesa[curr].link = clonelast = curr# Reverse the links to compute cnt and sum_i using topological sortV = len(sa)size = [0]*Vsum_i = [0]*Vstates = sorted(range(V), key=lambda x: -sa[x].len)for v in states:if sa[v].link >=0:sa[sa[v].link].cnt += sa[v].cntsa[sa[v].link].sum_i += sa[v].sum_ifor v in range(V):size[v] = sa[v].cntsum_i[v] = sa[v].sum_i# Precompute for each state, transitions sorted and cumulative countsfrom collections import defaultdicttrans = [defaultdict(int) for _ in range(V)]for v in range(V):for c, to in sa[v].next.items():trans[v][c] = to# For each state, precompute sorted list of transitionssorted_trans = [ [] for _ in range(V) ]for v in range(V):sorted_c = sorted(trans[v].keys())sorted_trans[v] = sorted_c# Precompute prefix sums for transitionsprefix_sum = [{} for _ in range(V)]for v in range(V):sorted_c = sorted_trans[v]cnt = 0sum_cnt = 0sum_total = 0sum_sum_i = 0temp = []for c in sorted_c:to = trans[v][c]cnt += size[to]sum_cnt += size[to]sum_total += size[to]sum_sum_i += sum_i[to]temp.append( (c, sum_cnt, sum_total, sum_sum_i) )prefix_sum[v] = temp# Process each queryfor L, R in queries:len_U = R - L +1if len_U ==0:print(0)continueU = S[L:R+1]current = 0cnt_condition2 = 0cnt_condition1 = 0states_m = []max_m = 0valid = Truefor m in range(len_U):c = U[m]if c not in trans[current]:valid = Falsebreaknext_state = trans[current][c]states_m.append( (m+1, current, next_state) )current = next_statemax_m = m+1if valid:states_m.append( (max_m+1, current, None) )else:pass# Compute condition2: sum of min(m_i, len_U-1)# where m_i is the length of the longest prefix of U starting at i# We need to find for each i, the maximum m where S[i..i+m-1] == U[0..m-1]# This is equivalent to the length of the longest prefix of U that is a substring starting at i# But computing this for all i is expensive, so we use the suffix automaton to track the prefixes# Track the states for each prefix of Uprefix_states = []current =0for m in range(len_U):if m >= len(U):breakc = U[m]if c not in sa[current].next:breakcurrent = sa[current].next[c]prefix_states.append( (m+1, current) )sum_condition2 =0for m, s in prefix_states:if m >= len_U:continuesum_condition2 += size[s]# Now, compute condition1sum_condition1 =0current_len =0current_node =0for k in range(len_U):if k >0:c = U[k-1]if c not in sa[current_node].next:breakcurrent_node = sa[current_node].next[c]current_len +=1if current_len != k:breakif k >= len_U:breaknext_char = U[k] if k < len_U else Noneif k < len_U-1:passelse:passtarget_c = U[k] if k < len(U) else Noneif k >= len_U:breakif current_node ==0 and current_len !=k:break# Find all transitions from current_node with c < target_c# Use the precomputed sorted transitions and prefix sumsif target_c is None:continuesorted_c_list = sorted_trans[current_node]low =0high = len(sorted_c_list)left =0while low < high:mid = (low+high)//2if sorted_c_list[mid] < target_c:low = mid +1else:high = midcount_less = lowif count_less ==0:continue# Get the sum of cnt and sum_i for transitions with c < target_cif count_less > len(prefix_sum[current_node]):sum_cnt = prefix_sum[current_node][-1][1] if prefix_sum[current_node] else 0sum_sum_i = prefix_sum[current_node][-1][3] if prefix_sum[current_node] else 0else:if count_less ==0:sum_cnt =0sum_sum_i =0else:sum_cnt = prefix_sum[current_node][count_less-1][1]sum_sum_i = prefix_sum[current_node][count_less-1][3]m = k+1term = (N - m +1) * sum_cnt - sum_sum_isum_condition1 += max(term, 0)total = sum_condition1 + sum_condition2print(total)if __name__ == '__main__':main()