import sys


def min_cost(k, x, y, values):
    """Return the minimum of ``t*y + x*sum(max(0, c - t) for c in needs)`` over integers t >= 0.

    Each input value ``v`` is first turned into ``c = (v - 2) // k + 1``,
    which equals ``ceil((v - 1) / k)`` for ``k > 0`` (Python ``//`` floors,
    so this also holds for small ``v``; e.g. ``v == 1`` gives ``0``).

    Args:
        k: Positive divisor applied to every value (assumes ``k >= 1``).
        x: Per-unit cost of handling one item individually (assumes ``x >= 0``).
        y: Cost of one step applied to all items at once (assumes ``y >= 0``).
        values: Iterable of integers; presumably each is ``>= 1`` so every
            need ``c`` is non-negative -- TODO confirm against the problem
            constraints.

    Returns:
        The minimal total cost as an int.
    """
    if x == 0:
        # Individual steps are free, so the optimum is t = 0 with zero cost.
        # (The original divided by x here and crashed.)
        return 0

    needs = sorted(((v - 2) // k + 1 for v in values), reverse=True)

    # The cost is convex in t. Its slope is y - x * #{c : c > t}, so raising t
    # pays off only while more than y / x items still need work.
    m = y // x
    if m >= len(needs):
        # One global step saves at most x * len(needs) <= x * m <= y,
        # so global steps never help: handle everything individually.
        return x * sum(needs)

    # Stop at the (m+1)-th largest need: above it, at most m items remain
    # (slope >= 0). Below it, at least m+1 remain (slope < 0).
    pivot = needs[m]
    return pivot * y + x * (sum(needs[:m]) - pivot * m)


def main():
    """Read ``n k x y`` followed by ``n`` integers from stdin and print the answer."""
    tokens = sys.stdin.read().split()
    n, k, x, y = map(int, tokens[:4])
    values = [int(tok) for tok in tokens[4:4 + n]]
    print(min_cost(k, x, y, values))


if __name__ == "__main__":
    main()