# Reads a grid of integers from stdin, repeatedly rewrites its 2x2 windows
# according to four rules until none of them matches anywhere, and prints the
# sum of all cells that remain.
#
# Input:  a line "h w", followed by h lines of whitespace-separated integers.
# Output: a single integer (the sum of every cell of the final grid).

# Match predicates, tried in this order. Each receives one 2x2 window as the
# tuple (top-left, top-right, bottom-left, bottom-right). Every rule has the
# same effect on a matching window: each of its four cells v becomes
# int(not v). For the first rule that is the same as clearing the window,
# because all four cells are 1 when it matches.
_RULES = (
    lambda window: window.count(1) == 4,    # four 1s  -> window becomes all 0
    lambda window: window.count(1) == 3,    # three 1s -> flipped window
    lambda window: window == (0, 0, 1, 1),  # 1s in bottom row -> move to top row
    lambda window: window == (0, 1, 0, 1),  # 1s in right column -> move to left
)


def _run_rule(grid, h, w, matches):
    """Apply one rule until a complete scan of the grid changes nothing.

    Windows are visited row by row, left to right, and a matching window is
    rewritten immediately, so windows visited later in the same scan already
    see the updated cells.

    Args:
        grid: list of row lists; modified in place.
        h: number of rows to scan (windows start in rows 0 .. h-2).
        w: number of columns to scan (windows start in columns 0 .. w-2).
        matches: predicate taking the window tuple and returning a bool.

    Returns:
        True if at least one window was rewritten, False otherwise.
    """
    rewrote_any = False
    while True:
        rewrote_this_scan = False
        for i in range(h - 1):
            upper, lower = grid[i], grid[i + 1]
            for j in range(w - 1):
                window = (upper[j], upper[j + 1], lower[j], lower[j + 1])
                if not matches(window):
                    continue
                # Flip the four cells in the order TL, TR, BL, BR, reading
                # each cell's current value at the moment it is written.
                for row in (upper, lower):
                    for col in (j, j + 1):
                        row[col] = int(not row[col])
                rewrote_this_scan = True
        if not rewrote_this_scan:
            return rewrote_any
        rewrote_any = True


def solve(h, w, grid):
    """Rewrite ``grid`` until no rule matches and return the sum of its cells.

    Each round runs the four rules in order, every rule to its own fixed
    point; rounds repeat until a whole round rewrites nothing.

    Args:
        h: number of rows covered by the window scan.
        w: number of columns covered by the window scan.
        grid: list of row lists of integers; modified in place.

    Returns:
        The sum of all values in all rows of ``grid`` after the rewriting.
    """
    while True:
        progressed = False
        # Deliberately not any(<generator>): every rule must run in each
        # round, even after an earlier rule has already made progress.
        for matches in _RULES:
            if _run_rule(grid, h, w, matches):
                progressed = True
        if not progressed:
            break
    return sum(sum(row) for row in grid)


def main():
    """Read the grid from stdin and print the result of ``solve``."""
    h, w = map(int, input().split())
    grid = [list(map(int, input().split())) for _ in range(h)]
    print(solve(h, w, grid))


if __name__ == "__main__":
    main()