def count_solutions(n: int, b: int) -> int:
    """Count triples (x, y, z) in range(b)**3 with x**n + y**n == z**n (mod b).

    Args:
        n: Exponent applied to each of x, y and z.
        b: Modulus; x, y and z each range over 0 .. b-1.

    Returns:
        The number of such triples. Returns 0 when ``b <= 0``, because there
        are no values to range over.

    Raises:
        ValueError: propagated from ``pow`` when ``n`` is negative and some
            ``x`` has no inverse modulo ``b``.
    """
    # counts[r] = how many x in [0, b) satisfy x**n % b == r.
    counts = [0] * b
    for x in range(b):
        counts[pow(x, n, b)] += 1

    # Only residues that actually occur can contribute a nonzero product,
    # so skip the rest instead of scanning all b*b residue pairs.
    present = [(r, c) for r, c in enumerate(counts) if c]

    total = 0
    for i, ci in present:
        for j, cj in present:
            # Every x with x**n ≡ i and y with y**n ≡ j pairs with every z
            # whose n-th power lands on (i + j) mod b.
            total += ci * cj * counts[(i + j) % b]
    return total


def main() -> None:
    """Read "N B" from one line of stdin and print count_solutions(N, B)."""
    n, b = map(int, input().split())
    print(count_solutions(n, b))


if __name__ == "__main__":
    main()