"""Count, over all unordered pairs, how many powers of p divide their difference.

Input (stdin): ``n p`` followed by ``n`` integers.
Output (stdout): the total count produced by ``count_divisible_pairs``.
"""
import sys
from collections import Counter


def count_divisible_pairs(a, p):
    """Return the number of (pair, k) combinations where p**k divides the pair's difference.

    For every unordered pair i < j and every exponent k >= 1 with
    ``p**k <= max(a) - min(a)``, one is added when ``a[i] ≡ a[j] (mod p**k)``.
    For a pair with distinct values, this contributes exactly the p-adic
    valuation of ``a[i] - a[j]``: any power dividing a nonzero difference is
    at most ``|a[i] - a[j]| <= max(a) - min(a)``.

    NOTE(review): a pair of equal values has difference 0, which every power
    divides. It is counted once per power up to ``max(a) - min(a)``, not
    infinitely. Confirm this matches the intended problem statement.

    Args:
        a: Sequence of integers. Negative values are fine, because Python's
           ``%`` with a positive modulus yields a canonical residue.
        p: Base of the powers. Must be >= 2.

    Returns:
        The total count as an int. Returns 0 for empty or single-element input.

    Raises:
        ValueError: If ``p < 2``. With p == 1 the powers never grow, so the
            loop would never terminate. With p == 0 the modulus is zero.
    """
    if p < 2:
        raise ValueError("p must be at least 2")
    if not a:
        return 0

    # Only the spread of the values matters, so there is no need to sort.
    max_diff = max(a) - min(a)
    total = 0
    power = p
    while power <= max_diff:
        # Two values share a residue exactly when `power` divides their difference.
        residues = Counter(num % power for num in a)
        total += sum(c * (c - 1) // 2 for c in residues.values())
        power *= p
    return total


def main():
    """Read ``n p`` and n integers from stdin, then print the pair count."""
    data = sys.stdin.read().split()
    n, p = int(data[0]), int(data[1])
    # Token-based parsing tolerates the array spanning several lines.
    a = [int(tok) for tok in data[2:2 + n]]
    print(count_divisible_pairs(a, p))


if __name__ == "__main__":
    main()