import sys
def solve(k, x, y, values):
    """Return the minimum total cost for *values* under costs *x* and *y*.

    Each value ``a`` is first mapped to ``c = (a - 2) // k + 1``, which
    equals ``ceil((a - 1) / k)``. The result is

        min over integers t >= 0 of  t * y + x * sum(max(0, c - t))

    Raising ``t`` by one costs ``y`` and saves ``x`` for every ``c`` still
    above ``t``. It therefore pays off while more than ``y // x`` of the
    ``c`` exceed ``t``. So the optimum is ``t = c_desc[y // x]`` (the
    ``(y // x + 1)``-th largest ``c``), or ``t = 0`` when there are at most
    ``y // x`` values.

    NOTE(review): the "minimum" characterisation assumes x, y, k >= 1 and
    every value >= 1 (so every c >= 0). These are presumably the problem
    constraints; confirm. ``x == 0`` raises ZeroDivisionError.

    Args:
        k: Divisor applied to each value.
        x: Weight applied to each unit of ``c`` left above ``t``.
        y: Weight applied to each unit of ``t``.
        values: Iterable of ints.

    Returns:
        The minimum cost as an int.
    """
    costs = sorted(((v - 2) // k + 1 for v in values), reverse=True)
    m = y // x
    if m >= len(costs):
        # Never worth paying y: pay x for every unit of every c.
        return x * sum(costs)
    t = costs[m]
    # The top m entries are >= t and each contributes (c - t).
    # Entries from index m onward are <= t and contribute nothing.
    return t * y + x * (sum(costs[:m]) - t * m)


def main():
    """Read ``n k x y`` followed by ``n`` ints from stdin and print the answer."""
    # Token-based reading tolerates values wrapped across several lines.
    data = sys.stdin.read().split()
    n, k, x, y = map(int, data[:4])
    values = [int(tok) for tok in data[4:4 + n]]
    print(solve(k, x, y, values))


if __name__ == "__main__":
    main()