"""Read two integers n and m from stdin and print min(n, m)!.

Because k! divides l! whenever k <= l, min(n, m)! is also gcd(n!, m!).
"""
import math


def factorial_of_min(n, m):
    """Return min(n, m)!, i.e. gcd(n!, m!).

    NOTE(review): the previous version used ``n|m`` (bitwise OR) as the
    factorial bound. That value is >= max(n, m) and meaningless here; it
    was presumably meant to be ``min(n, m)``. Confirm against the task
    statement.

    Raises:
        ValueError: if min(n, m) is negative (factorial is undefined).
    """
    return math.factorial(min(n, m))


def main():
    """Parse "n m" from one line of stdin and print min(n, m)!."""
    n, m = map(int, input().split())
    print(factorial_of_min(n, m))


if __name__ == "__main__":
    main()