"""Walk 0, d, 2d, ... modulo n and report how many new positions are reached.

Reads two integers ``n d`` from one line of stdin and prints the count.
"""


def count_new_positions(n, d):
    """Return how many previously unseen positions follow 0 before a repeat.

    Starting at position 0, repeatedly advance by ``d`` modulo ``n``. The
    walk stops at the first position that has already been visited; the
    result is the number of distinct positions reached *after* the start.
    Mathematically this equals ``n // math.gcd(n, d) - 1``, since the
    positions visited form the cyclic subgroup generated by ``d`` in Z/nZ.

    Args:
        n: Number of positions (the modulus); must be >= 1.
        d: Step size; any integer (zero and negatives are allowed).

    Returns:
        The count of new positions visited after position 0.

    Raises:
        ValueError: If ``n < 1`` (there are no positions to walk over).
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    # A set gives O(1) membership tests; the previous list-based version
    # paid O(n) per step for both remove() and `in`, i.e. O(n^2) overall.
    visited = {0}
    position = 0
    count = 0
    while True:
        position = (position + d) % n
        if position in visited:
            break
        visited.add(position)
        count += 1
    return count


def main():
    """Read ``n d`` from stdin and print the number of new positions."""
    n, d = map(int, input().split())
    print(count_new_positions(n, d))


if __name__ == "__main__":
    main()