import java.util.Arrays;
import java.util.Scanner;

/**
 * Reads {@code n k x y} followed by {@code n} integers from standard input and prints the
 * total cost computed by {@link #minCost(int[], int, int, int)}.
 */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();
        int x = sc.nextInt();
        int y = sc.nextInt();
        int[] values = new int[n];
        for (int i = 0; i < n; i++) {
            values[i] = sc.nextInt();
        }
        System.out.println(minCost(values, k, x, y));
    }

    /**
     * Greedily computes the total cost of processing all values.
     *
     * <p>Values are visited in ascending order. For each value {@code z}, the number of
     * {@code k}-sized steps needed to raise the running threshold {@code base} to at least
     * {@code z} is {@code ceil((z - base) / k)}. The threshold starts at 1, so values {@code <= 1}
     * need no steps. NOTE(review): this presumably means "reduce each value to 1"; confirm
     * against the problem statement.
     *
     * <p>While the number of not-yet-visited values times {@code x} exceeds {@code y}, steps are
     * paid at cost {@code y} each and advance the shared threshold, so they also count toward
     * every later value. Otherwise each step is paid at cost {@code x} for that value alone,
     * and the threshold stays put.
     *
     * @param values the input values; not modified
     * @param k size of one step; must be positive
     * @param x cost of one step applied to a single value
     * @param y cost of one step applied to all remaining values
     * @return the accumulated cost
     * @throws IllegalArgumentException if {@code k <= 0}
     */
    static long minCost(int[] values, int k, int x, int y) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        int[] sorted = values.clone();
        Arrays.sort(sorted);

        long base = 1; // values <= base need no further steps
        long total = 0;
        for (int i = 0; i < sorted.length; i++) {
            long remaining = sorted.length - i;
            // long arithmetic: z - base + k - 1 can exceed Integer.MAX_VALUE for large z and k.
            // The clamp guards against negative counts if a value lies far below base.
            long steps = Math.max(0L, (sorted[i] - base + k - 1) / k);
            // Same as "remaining > y / x" for x > 0, but without division (no crash on x == 0).
            if (remaining * x > y) {
                base += steps * k;
                total += steps * y;
            } else {
                total += steps * x;
            }
        }
        return total;
    }
}