結果

問題 No.2332 Make a Sequence
ユーザー gew1fw
提出日時 2025-06-12 14:01:22
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 4,646 bytes
コンパイル時間 611 ms
コンパイル使用メモリ 82,240 KB
実行使用メモリ 133,556 KB
最終ジャッジ日時 2025-06-12 14:03:11
合計ジャッジ時間 54,287 ms
ジャッジサーバーID (参考情報) judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 58 TLE * 3
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 10**18 + 3  # modulus of the polynomial rolling hash over integer sequences
BASE = 911382629  # multiplier (base) of the polynomial rolling hash

def _z_function(s):
    """Return the Z-array of the sequence ``s``.

    ``z[i]`` is the length of the longest common prefix of ``s`` and
    ``s[i:]``; ``z[0] == len(s)``.  Runs in O(len(s)).
    """
    n = len(s)
    z = [0] * n
    if n:
        z[0] = n
    # [left, right) is the rightmost window known to equal a prefix of s.
    left = right = 0
    for i in range(1, n):
        k = 0
        if i < right:
            # Reuse the mirrored value inside the current window.
            k = min(right - i, z[i - left])
        while i + k < n and s[k] == s[i + k]:
            k += 1
        z[i] = k
        if i + k > right:
            left, right = i, i + k
    return z


def _prefix_match_lengths(a, b):
    """Return ``m`` with ``m[j]`` = largest k such that ``b[j:j+k] == a[:k]``.

    The lengths are exact; no hashing is involved.  A single Z-array is
    computed over ``a + [sentinel] + b``.  The sentinel is smaller than every
    element, so it never matches and no match can run past the end of ``a``.
    """
    sentinel = min(min(a, default=0), min(b, default=0)) - 1
    z = _z_function(list(a) + [sentinel] + list(b))
    return z[len(a) + 1:]


def _min_build_cost(a, b, c):
    """Return the minimum cost to build ``b`` from prefixes of ``a``, or -1.

    From a built prefix of length j, one step appends ``a[:k]`` (k >= 1,
    which must equal ``b[j:j+k]``) at cost ``c[j] * k``.  Let dp[j] be the
    cheapest cost to reach length j.  Then

        dp[x] = min over j with j < x <= j + match[j] of
                c[j] * x + (dp[j] - c[j] * j)

    so each reachable j adds the line y = c[j]*x + dp[j] - c[j]*j on the
    x-interval [j+1, j+match[j]].  A Li Chao segment tree holds these
    interval-restricted lines and answers point-minimum queries:
    O(log^2 M) per insertion and O(log M) per query.

    Returns 0 for an empty ``b``.
    """
    m = len(b)
    if m == 0:
        return 0
    INF = 1 << 60
    match = _prefix_match_lengths(a, b)

    size = 1
    while size < m + 1:  # leaves must cover x = 0..m
        size <<= 1
    # Node k stores the line y = slope[k] * x + icpt[k].  The pair (0, INF)
    # encodes "no line", so an empty node never wins a comparison.
    slope = [0] * (2 * size)
    icpt = [INF] * (2 * size)

    def push_line(node, lo, hi, la, lb):
        """Li Chao insertion of la*x+lb into node's subtree covering [lo, hi]."""
        while True:
            mid = (lo + hi) >> 1
            wa = slope[node]
            wb = icpt[node]
            if la * mid + lb < wa * mid + wb:
                # Keep the line that is better at mid; carry the other one down.
                slope[node] = la
                icpt[node] = lb
                la, lb, wa, wb = wa, wb, la, lb
            if lo == hi:
                return
            # Two lines cross at most once, so the loser can still win on
            # at most one half of the range.
            if la * lo + lb < wa * lo + wb:
                node <<= 1
                hi = mid
            elif la * hi + lb < wa * hi + wb:
                node = (node << 1) | 1
                lo = mid + 1
            else:
                return

    def add_segment(l, r, la, lb):
        """Insert la*x+lb restricted to x in [l, r] (inclusive)."""
        lo = l + size
        hi = r + size + 1
        while lo < hi:
            for node in ((lo,) if lo & 1 else ()) + ((hi - 1,) if hi & 1 else ()):
                # Recover the x-range covered by this heap-indexed node.
                depth = node.bit_length() - 1
                width = size >> depth
                start = (node - (1 << depth)) * width
                push_line(node, start, start + width - 1, la, lb)
            if lo & 1:
                lo += 1
            if hi & 1:
                hi -= 1
            lo >>= 1
            hi >>= 1

    def point_min(x):
        """Minimum over all stored lines whose interval contains x (INF if none)."""
        node = x + size
        best = INF
        while node:
            v = slope[node] * x + icpt[node]
            if v < best:
                best = v
            node >>= 1
        return best

    dp = [INF] * (m + 1)
    dp[0] = 0
    for j in range(m):
        # dp[j] is final here: every line that can reach x = j comes from j' < j.
        if dp[j] < INF and match[j]:
            add_segment(j + 1, j + match[j], c[j], dp[j] - c[j] * j)
        dp[j + 1] = point_min(j + 1)
    return dp[m] if dp[m] < INF else -1


def main():
    """Read ``N M``, then A (N values), B (M values), C (M values) from stdin.

    Print the minimum cost to build B from prefixes of A, or -1 if B cannot
    be built.  Prints 0 when M == 0.
    """
    tokens = sys.stdin.read().split()
    n, m = int(tokens[0]), int(tokens[1])
    values = list(map(int, tokens[2:2 + n + 2 * m]))
    a = values[:n]
    b = values[n:n + m]
    c = values[n + m:n + 2 * m]
    print(_min_build_cost(a, b, c))
    
if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    main()
0