def main(n=None, b=None):
    """Count triples (x, y, z) in range(b) with x**n + y**n == z**n (mod b).

    When called with no arguments (the original script behaviour), ``n`` and
    ``b`` are read as two whitespace-separated integers from one line of
    standard input.

    Args:
        n: Exponent. Read from stdin together with ``b`` when both are None.
        b: Modulus; x, y and z each range over 0 .. b-1.

    Returns:
        The number of ordered triples (x, y, z) satisfying the congruence.

    Raises:
        TypeError: If exactly one of ``n`` / ``b`` is supplied.
        ValueError: If the stdin line does not hold exactly two integers, or
            if ``pow`` cannot compute a negative power modulo ``b``.
    """
    if n is None and b is None:
        n, b = map(int, input().split())
    elif n is None or b is None:
        raise TypeError("pass both n and b, or neither to read them from stdin")

    # counts[r] = number of x in range(b) whose n-th power is r (mod b).
    counts = [0] * b
    for x in range(b):
        counts[pow(x, n, b)] += 1

    # Only residues that are actually hit contribute; every other term of the
    # double sum is a product with a zero count, so skipping them is exact.
    residues = [r for r in range(b) if counts[r]]

    # For each (x**n, y**n) residue pair, count the z whose n-th power
    # equals their sum mod b.
    return sum(
        counts[r] * counts[s] * counts[(r + s) % b]
        for r in residues
        for s in residues
    )


if __name__ == "__main__":
    print(main())