"""Count, over all pairs of input values, how many of p, p**2, ..., p**29 divide
their difference, and print the total.

Input (stdin): first line ``n p``; second line the integers ``a``.
"""
from collections import Counter, defaultdict, deque
from sys import stdin

MAX_EXPONENT = 29  # highest power of p considered (levels 1..29 inclusive)


def count_pairs(p, a, max_exp=MAX_EXPONENT):
    """Return the sum over pairs i < j of #{k in 1..max_exp : a[i] % p**k == a[j] % p**k}.

    Equivalently, for each pair, the number of k in 1..max_exp such that
    p**k divides a[i] - a[j]. Equal values therefore contribute ``max_exp``.

    Args:
        p: base of the moduli p, p**2, ..., p**max_exp (p == 0 raises
            ZeroDivisionError when ``a`` is non-empty).
        a: sequence of integers.
        max_exp: highest exponent considered; defaults to 29.

    Returns:
        The total pair count as an int (0 for empty ``a``).
    """
    if not a:
        return 0
    spread = max(a) - min(a)
    total = 0
    for k in range(1, max_exp + 1):
        mod = p ** k
        groups = Counter(x % mod for x in a)
        pairs = sum(c * (c - 1) // 2 for c in groups.values())
        if mod > spread:
            # Two values with equal residues differ by a multiple of mod, but
            # their difference is at most spread < mod, so they must be equal.
            # Every later modulus is at least as large in magnitude, so each
            # remaining level adds exactly the same number of pairs.
            return total + pairs * (max_exp - k + 1)
        total += pairs
    return total


def read_input(stream=stdin):
    """Parse ``n p`` from the first line and the integers from the second.

    Returns:
        ``(p, a)``; ``n`` is read but not needed.

    ``str.split()`` discards the line terminator whether it is ``\\n``,
    ``\\r\\n`` or missing. Slicing off the last character instead would drop
    a digit when the final line has no trailing newline.
    """
    _n, p = map(int, stream.readline().split())
    a = list(map(int, stream.readline().split()))
    return p, a


def main():
    """Read the problem from stdin and print the answer."""
    p, a = read_input()
    print(count_pairs(p, a))


if __name__ == "__main__":
    main()