def max_value(capacity, costs, values):
    """Return the best total value reachable with total cost <= ``capacity``.

    Each input item ``i`` is expanded into several 0/1 items that all cost
    ``costs[i]``: one worth ``values[i]``, then extra copies worth
    ``values[i] // 2``, ``values[i] // 4``, ... for as long as that halved
    value stays positive.  A 0/1 knapsack is then solved over the expanded
    items.  Because copies of the same item share a cost and have
    non-increasing values, this behaves like "item i may be taken again,
    each extra pick worth half (floored) of the previous one".
    NOTE(review): that reading of the problem is inferred from the code,
    not from a spec -- confirm against the original task statement.

    Args:
        capacity: Maximum total cost allowed (must be >= 0).
        costs: Cost of each item (each must be >= 0).
        values: Base value of each item; same length as ``costs``.

    Returns:
        The maximum achievable total value (an ``int``; 0 if nothing fits).

    Raises:
        ValueError: If ``costs`` and ``values`` differ in length, or if
            ``capacity`` or any cost is negative.
    """
    if len(costs) != len(values):
        raise ValueError("costs and values must have the same length")
    if capacity < 0:
        raise ValueError("capacity must be non-negative")

    # Expand each item into its halving copies without mutating the inputs.
    items = []
    for cost, value in zip(costs, values):
        if cost < 0:
            raise ValueError("costs must be non-negative")
        items.append((cost, value))
        extra = value // 2
        while extra > 0:
            items.append((cost, extra))
            extra //= 2

    # best[j] = best value using exactly total cost j; -inf marks "unreachable".
    best = [float("-inf")] * (capacity + 1)
    best[0] = 0
    for cost, value in items:
        # Walk capacities downward so each expanded item is used at most once
        # (the classic 1-D 0/1 knapsack; same optimum as the 2-D table).
        for j in range(capacity, cost - 1, -1):
            candidate = best[j - cost] + value
            if candidate > best[j]:
                best[j] = candidate
    # best[0] == 0, so the maximum is always a real (int) value, never -inf.
    return max(best)


def main():
    """Read the puzzle input from stdin and print the optimal value.

    Input format (one value or list per line): capacity, item count ``n``,
    the ``n`` costs, the ``n`` values.
    """
    capacity = int(input())
    n = int(input())
    costs = list(map(int, input().split()))
    values = list(map(int, input().split()))
    print(max_value(capacity, costs[:n], values[:n]))


if __name__ == "__main__":
    main()