"""Single-test-case greedy solver, competitive-programming style.

Input (stdin):  line 1 -> ``n k x y``; line 2 -> ``n`` integers.
Output (stdout): one integer, computed by :func:`min_cost`.
"""
import sys

readline = sys.stdin.readline


def ns():
    """Read one line as a string without trailing whitespace."""
    return readline().rstrip()


def ni():
    """Read one line as a single int."""
    return int(readline().rstrip())


def nm():
    """Read one line as a lazy iterator of ints."""
    return map(int, readline().split())


def nl():
    """Read one line as a list of ints."""
    return list(map(int, readline().split()))


def _ceil_div(num, den):
    """Return ceil(num / den) for integers, assuming den > 0."""
    return -(-num // den)


def min_cost(n, k, x, y, values):
    """Return the cost this script prints for one test case.

    Every input value ``v`` is shifted to ``v - 1``. For a chosen integer
    ``g`` the cost is::

        y * g + x * sum(ceil(max(v - 1 - g * k, 0) / k) for the f largest v)

    ``g`` is the smallest integer with ``g * k`` at least the (f + 1)-th
    largest shifted value, where ``f = min(y // x, n)``. Raising ``g`` by one
    adds ``y`` and saves ``x`` for every value still above ``g * k``, so it
    only pays off while more than ``y // x`` values remain above it. That is
    why this ``g`` minimises the expression (for x > 0, y >= 0).

    Args:
        n: Number of values; caps ``f``. Presumably equals ``len(values)``.
        k: Step size; assumed to be a positive integer.
        x: Cost per individual step; assumed >= 0.
        y: Cost per global step; assumed >= 0.
        values: The integers read from the second input line.

    Returns:
        The computed integer cost.
    """
    # Shift every value down by 1 and add a 0 sentinel so a[-f - 1] stays
    # valid when f == n.
    a = sorted([0] + [v - 1 for v in values])
    # With x == 0 individual steps are free, so every value is handled that
    # way (the previous code raised ZeroDivisionError on y // x).
    f = n if x == 0 else min(y // x, n)
    # Smallest g with g * k >= a[-f - 1], the (f + 1)-th largest entry; after
    # that, at most f values remain above g * k.
    g = _ceil_div(a[-f - 1], k)
    # Explicit start index: a[-f:] would be the WHOLE list when f == 0.
    top = a[len(a) - f:]
    return y * g + x * sum(_ceil_div(max(v - g * k, 0), k) for v in top)


def solve():
    """Read one test case from stdin and print its cost."""
    n, k, x, y = nm()
    print(min_cost(n, k, x, y, nl()))


if __name__ == "__main__":
    solve()