import java.util.Arrays;
import java.util.Scanner;

/**
 * Reads {@code n k x y} followed by {@code n} integer values from standard input and prints the
 * minimum total cost computed by {@link #minimumCost(long[], long, long)}.
 *
 * <p>Each value {@code v} is first converted into a step count {@code s = ceil((v - 1) / k)},
 * clamped at 0. This is the number of size-{@code k} reductions needed to bring {@code v} down to at
 * most 1.
 *
 * <p>A plan picks a number {@code t} of steps that count toward every value at cost {@code y} each.
 * It then pays {@code x} for each remaining step of each value:
 * {@code cost(t) = t * y + x * sum_j max(0, s_j - t)}.
 *
 * <p>NOTE(review): the "shared step costs y / individual step costs x" reading is inferred from the
 * formula; confirm against the problem statement. Costs {@code x}, {@code y} are presumably
 * non-negative.
 */
public class Main {
    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            long k = sc.nextInt();
            long x = sc.nextInt();
            long y = sc.nextInt();
            long[] steps = new long[n];
            for (int i = 0; i < n; i++) {
                steps[i] = stepsNeeded(sc.nextInt(), k);
            }
            System.out.println(minimumCost(steps, x, y));
        }
    }

    /**
     * Returns {@code ceil((value - 1) / k)} for positive {@code k}, clamped at 0.
     *
     * <p>The arithmetic is done in {@code long} because {@code value + k - 2} can exceed
     * {@link Integer#MAX_VALUE} for large int inputs. Values {@code <= 1} need no steps, hence the
     * clamp.
     */
    private static long stepsNeeded(long value, long k) {
        return Math.max(0L, (value - 1 + k - 1) / k);
    }

    /**
     * Returns {@code min over t of t * y + x * sum_j max(0, steps[j] - t)}.
     *
     * <p>For non-negative costs the function is convex and piecewise linear in {@code t}, with
     * breakpoints at the step counts. It is therefore enough to evaluate {@code t = 0} and
     * {@code t = steps[i]} for each {@code i}.
     *
     * <p>Overflowing candidates saturate at {@link Long#MAX_VALUE} instead of wrapping. The
     * {@code t = max} candidate always fits in a {@code long}, so saturation never hides the true
     * minimum.
     *
     * @param steps per-value step counts; sorted in place
     * @param x cost per individual step
     * @param y cost per shared step
     * @return the minimum cost, or 0 when {@code steps} is empty
     */
    private static long minimumCost(long[] steps, long x, long y) {
        int n = steps.length;
        if (n == 0) {
            return 0;
        }
        Arrays.sort(steps);
        // suffixSum holds the sum of steps[i + 1 .. n - 1] at the top of each iteration.
        long suffixSum = steps[n - 1];
        long best = saturatingMultiply(steps[n - 1], y);
        for (int i = n - 2; i >= 0; i--) {
            long countAbove = n - i - 1;
            // Individual steps still needed by the values above steps[i] after t = steps[i] shared ones.
            long excess = suffixSum - countAbove * steps[i];
            long candidate =
                    saturatingAdd(saturatingMultiply(steps[i], y), saturatingMultiply(excess, x));
            best = Math.min(best, candidate);
            suffixSum += steps[i];
        }
        // t = 0: every step is paid individually.
        return Math.min(best, saturatingMultiply(suffixSum, x));
    }

    /**
     * Multiplies two values, clamping to {@link Long#MAX_VALUE} when a product of two positive
     * operands would overflow. Operands {@code <= 0} fall back to plain multiplication.
     */
    private static long saturatingMultiply(long a, long b) {
        if (a <= 0 || b <= 0) {
            return a * b;
        }
        return a > Long.MAX_VALUE / b ? Long.MAX_VALUE : a * b;
    }

    /**
     * Adds two values, clamping to {@link Long#MAX_VALUE} when a sum of two non-negative operands
     * would overflow. Negative operands fall back to plain addition.
     */
    private static long saturatingAdd(long a, long b) {
        if (a < 0 || b < 0) {
            return a + b;
        }
        return a > Long.MAX_VALUE - b ? Long.MAX_VALUE : a + b;
    }
}