def count_power_triples(exponent, modulus):
    """Count triples (x, y, z) in range(modulus)**3 with x**e + y**e == z**e (mod m).

    The count is built from a histogram of e-th power residues:
    ``counts[r]`` is the number of x in ``range(modulus)`` with
    ``pow(x, exponent, modulus) == r``. The number of matching triples is then
    the sum, over residue pairs (r1, r2), of
    ``counts[r1] * counts[r2] * counts[(r1 + r2) % modulus]``.

    Args:
        exponent: The power applied to each element. A negative exponent is
            passed straight to ``pow``, which raises ``ValueError`` when an
            element has no inverse modulo ``modulus``.
        modulus: The modulus. When it is zero or negative, ``range(modulus)``
            is empty and the result is 0.

    Returns:
        The number of matching ordered triples.
    """
    counts = [0] * modulus
    for x in range(modulus):
        counts[pow(x, exponent, modulus)] += 1

    # Residues that are never hit make every product they appear in zero,
    # so pair up only the residues that actually occur.
    occupied = [r for r in range(modulus) if counts[r]]
    return sum(
        counts[r1] * counts[r2] * counts[(r1 + r2) % modulus]
        for r1 in occupied
        for r2 in occupied
    )


def main():
    """Read "N B" from stdin and print the number of triples for exponent N mod B."""
    N, B = map(int, input().split())
    print(count_power_triples(N, B))


if __name__ == "__main__":
    main()