import sys

def int1(x):
    """Return int(x) shifted down by one (1-based -> 0-based)."""
    return int(x) - 1


def input():
    """Read one line from stdin with trailing whitespace removed.

    Intentionally shadows the builtin ``input``; the helpers below call it.
    A bytes-based alternative is ``sys.stdin.buffer.readline()``.
    """
    return sys.stdin.readline().rstrip()


def ii():
    """Read a line holding a single integer."""
    return int(input())


def i1():
    """Read a line holding a single integer, shifted down by one."""
    return int1(input())


def mi():
    """Return a lazy iterator of the whitespace-separated ints on one line."""
    return map(int, input().split())


def mi1():
    """Like ``mi`` but each value is shifted down by one."""
    return map(int1, input().split())


def li():
    """Read one line of whitespace-separated ints into a list."""
    return list(mi())


def li1():
    """Read one line of ints into a list, each shifted down by one."""
    return list(mi1())


def lli(n):
    """Read ``n`` lines, each parsed as a list of ints."""
    return [li() for _ in range(n)]


INF = float("inf")
# Common modulus; swap for 998244353 when a problem requires it.
mod = 10**9 + 7

n, m = mi()
a = li()
b = li()


def _runs_to_bits(counts, length):
    """Encode the first ``length`` entries of ``counts`` as a 0/1 list.

    Entry ``c`` contributes ``c`` zeros, and a single 1 separates
    consecutive entries (none before the first).
    """
    bits = []
    for k in range(length):
        if k > 0:
            bits.append(1)
        bits.extend([0] * counts[k])
    return bits


s = _runs_to_bits(a, n)
t = _runs_to_bits(b, m)

# Edit distance between s and t: inserting or deleting an element costs 1,
# replacing an element costs 1 when it differs and 0 when it matches.
# Only the previous DP row is needed, so two rows are kept instead of a grid.
# NOTE(review): this is O(len(s) * len(t)) pure-Python work; confirm the
# input sums are small enough for that.
x, y = len(s), len(t)
prev = list(range(y + 1))  # row for an empty prefix of s: y insertions
for i in range(1, x + 1):
    cur = [i] + [0] * y  # column 0: delete all i leading elements of s
    si = s[i - 1]
    for j in range(1, y + 1):
        cur[j] = min(
            prev[j - 1] + (si != t[j - 1]),  # match / substitute
            prev[j] + 1,  # delete s[i-1]
            cur[j - 1] + 1,  # insert t[j-1]
        )
    prev = cur
print(prev[y])