import sys
from collections import Counter

# NOTE(review): `mod` is never used below; the answer is printed as an exact
# integer. Confirm whether the task actually expects the result modulo this.
mod = 998244353


def count_pairs(A, p0, limit=10 ** 9):
    """Sum, over every power p = p0**k (k >= 1) with p <= limit, the number of
    unordered index pairs (i < j) such that A[i] % p == A[j] % p.

    Args:
        A: sequence of integers.
        p0: base of the powers used as moduli. Must satisfy abs(p0) >= 2;
            otherwise the sequence of powers never exceeds ``limit``.
        limit: inclusive upper bound on the moduli considered (default 10**9,
            the bound the original script hard-coded).

    Returns:
        The total pair count as an exact (unbounded) Python int. It is 0 when
        ``A`` has fewer than two elements or when ``p0 > limit``.

    Raises:
        ValueError: if p0 is -1, 0 or 1. The original loop hung for these
            values (or raised ZeroDivisionError for 0).
    """
    if abs(p0) < 2:
        raise ValueError("p0 must satisfy abs(p0) >= 2")
    ans = 0
    p = p0
    while p <= limit:
        # Group the elements by residue. A residue class of size x contributes
        # C(x, 2) = x*(x-1)/2 pairs that agree modulo p.
        cnt = Counter(a % p for a in A)
        ans += sum(x * (x - 1) // 2 for x in cnt.values())
        p *= p0
    return ans


def main():
    """Read `N p0` and then N integers from stdin, and print count_pairs(A, p0)."""
    input = sys.stdin.readline  # fast line reader, kept local to main
    N, p0 = map(int, input().split())  # N is read but len(A) is what is used
    A = list(map(int, input().split()))
    print(count_pairs(A, p0))


if __name__ == "__main__":
    main()