def solve(read=input):
    """Read ``N P`` and a list of integers, then count congruent pairs per power of P.

    For every power ``Q = P**k`` (k >= 1) with ``Q <= max(A)``, count the
    unordered index pairs ``i < j`` with ``A[i] % Q == A[j] % Q``. Return the
    total over all such powers.

    Args:
        read: zero-argument callable returning one input line. Defaults to
            the builtin ``input``, so ``solve()`` still reads from stdin.

    Returns:
        The summed pair count as an int. It is 0 when the list is empty or
        ``max(A) < P``.

    Raises:
        ValueError: if ``P < 2``. For such P the powers never exceed
            ``max(A)`` (P=1, negative P) or the modulus is zero (P=0), so
            the counting loop could not terminate normally.
    """
    from collections import Counter

    # N is not needed: the list length comes from the second line itself.
    _n, p = map(int, read().split())
    a = list(map(int, read().split()))
    if p < 2:
        raise ValueError(f"P must be at least 2, got {p}")
    # An empty list has no pairs; default=0 makes the loop below run zero times.
    a_max = max(a, default=0)

    total = 0
    q = p
    while q <= a_max:
        buckets = Counter(x % q for x in a)
        # Each residue class of size c contributes C(c, 2) pairs.
        total += sum(c * (c - 1) // 2 for c in buckets.values())
        q *= p
    return total


# ==================================================
if __name__ == "__main__":
    print(solve())