import sys
from collections import Counter

input = sys.stdin.readline
mod = 998244353  # not used by this script; kept so the module namespace is unchanged


def count_pairs(A, p0, limit=10 ** 9):
    """Sum, over every power p = p0**k (k >= 1) with p <= limit, the number
    of unordered index pairs (i, j), i < j, such that A[i] % p == A[j] % p.

    Args:
        A: sequence of integers.
        p0: base of the powers; must be at least 2.
        limit: largest power of p0 that is considered (inclusive).

    Returns:
        The total pair count as an int.

    Raises:
        ValueError: if p0 < 2. For p0 == 1 (or -1) the powers never exceed
            ``limit``, so the loop would not terminate. For p0 == 0 the
            modulus would be zero.
    """
    if p0 < 2:
        raise ValueError("p0 must be at least 2")
    total = 0
    p = p0
    while p <= limit:
        residues = Counter(a % p for a in A)
        # A residue class of size c contributes C(c, 2) pairs.
        total += sum(c * (c - 1) // 2 for c in residues.values())
        p *= p0
    return total


def main():
    """Read "N p0" and then the N values of A from stdin; print the count."""
    N, p0 = map(int, input().split())
    A = list(map(int, input().split()))
    print(count_pairs(A, p0))


if __name__ == "__main__":
    main()