"""Read two integers ``n`` and ``m`` from stdin and print ``(n % m)!``."""

import math


def factorial_of_remainder(n: int, m: int) -> int:
    """Return the factorial of ``n % m``.

    Python's ``%`` takes the sign of ``m``. A negative ``m`` therefore gives
    a non-positive remainder, and that case returns 1 (the empty product),
    as the original loop-based version did.

    Raises:
        ZeroDivisionError: if ``m`` is 0.
    """
    remainder = n % m
    # math.factorial rejects negative arguments. Clamping keeps the old
    # behavior: an empty range(1, remainder + 1) loop produced 1.
    return math.factorial(max(remainder, 0))


def main() -> None:
    """Read ``n m`` from one line of stdin and print ``(n % m)!``."""
    n, m = map(int, input().split())
    print(factorial_of_remainder(n, m))


if __name__ == "__main__":
    main()