from sys import stdin


def parse_input(text):
    """Parse the raw problem input into ``(n, k, x, y, healths)``.

    The first four whitespace-separated integers are N, K, X and Y; the next
    N integers are the values A.  Splitting on arbitrary whitespace (rather
    than slicing the last character off each line) keeps the final number
    intact when the input does not end with a newline.
    """
    tokens = text.split()
    n, k, x, y = map(int, tokens[:4])
    healths = list(map(int, tokens[4:4 + n]))
    return n, k, x, y, healths


def solve(n, k, x, y, healths):
    """Return the total cost produced by the greedy over ``healths``.

    Values are handled in ascending order; ``alive`` starts at ``n`` and
    drops by one per value handled.  A value that is already at or below 1,
    after the shared damage dealt so far, costs nothing.  Otherwise it needs
    ``ceil((h - 1) / k)`` hits to reach 1 or below.  Those hits are paid
    individually at ``x`` each when ``alive * x <= y``.  Otherwise they are
    paid as shared hits at ``y`` each, which also lower every later value by
    ``k`` per hit.

    ``k`` must be positive (``k == 0`` raises ZeroDivisionError).
    """
    alive = n
    shared_damage = 0  # total damage already applied to every remaining value
    cost = 0
    for h in sorted(healths):
        h -= shared_damage
        if h <= 1:
            alive -= 1
            continue
        hits = (h - 2) // k + 1  # == ceil((h - 1) / k): hits to bring h to <= 1
        if alive * x <= y:
            cost += hits * x
        else:
            shared_damage += hits * k
            cost += hits * y
        alive -= 1
    return cost


def main():
    """Read one test case from standard input and print its cost."""
    print(solve(*parse_input(stdin.read())))


if __name__ == "__main__":
    main()