import bisect  # NOTE(review): unused in the visible code; kept in case other tooling expects it
import sys

input = sys.stdin.readline


def _min_cost(hits, bhim, bhmr):
    """Return the minimum total cost to bring every count in *hits* to zero.

    Two operations are available:
      * a single-target step costing ``bhim``: lowers one element by 1;
      * an all-target step costing ``bhmr``: lowers every still-positive
        element by 1 at once.

    Greedy: process the distinct levels in ascending order. While the
    all-target step is strictly cheaper than doing the same work with
    single-target steps on every remaining element, use it to lift all
    remaining elements to the next level. The number of remaining elements
    only shrinks, so once single-target steps win they keep winning. The
    leftover work above the current level is then paid with ``bhim``.

    Args:
        hits: non-negative per-element step counts (any iterable of ints).
        bhim: cost of one single-target step.
        bhmr: cost of one all-target step.

    Returns:
        The total cost as an int (0 for empty input).
    """
    levels = sorted(hits)
    total = len(levels)
    cost = 0
    prev = 0  # level already reached by all-target steps
    for i, level in enumerate(levels):
        if level == prev:
            continue
        remaining = total - i  # elements at index >= i all need more than prev
        if bhmr >= bhim * remaining:
            # Single-target steps are now at least as cheap; pay the rest
            # individually, minus what all-target steps already covered.
            return cost + (sum(levels[i:]) - remaining * prev) * bhim
        cost += bhmr * (level - prev)
        prev = level
    return cost


def main(args):
    """Read one case from stdin and print the minimum total cost.

    Input (as parsed here): a line ``N K bhim bhmr`` followed by a line of
    integers. Each integer x is converted to ceil((x - 1) / K) steps
    (assumes K > 0 — TODO confirm against the problem statement). The names
    presumably refer to a single-target spell (bhim) and an all-target spell
    (bhmr); verify.

    Args:
        args: command-line arguments; unused.
    """
    _, K, bhim, bhmr = map(int, input().split())
    # -(-a // K) is ceil(a / K) for positive K, using Python's floor division.
    hits = [-(-(int(x) - 1) // K) for x in input().split()]
    print(_min_cost(hits, bhim, bhmr))


if __name__ == '__main__':
    main(sys.argv[1:])