def steps_before_repeat(n: int, d: int) -> int:
    """Count the hops taken from position 1 before a position would repeat.

    Positions are the residues ``0 .. n-1``.  The walk starts on position 1
    and each hop advances ``d`` places modulo ``n``.  It stops as soon as the
    current position has already been used, or after ``n`` positions have
    been used, whichever comes first.

    Args:
        n: Number of positions on the ring; must be positive.
        d: Hop length.  Any integer is accepted; it is reduced modulo ``n``
            on every hop.

    Returns:
        The number of distinct positions visited, minus one.  For ``n >= 2``
        this equals ``n // gcd(n, d) - 1``.

    Raises:
        ValueError: If ``n`` is not positive.  (The original script failed
            with ``NameError`` in that case because its loop variable was
            never bound.)
    """
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}")

    unvisited = set(range(n))
    position = 1
    # Starts at -1 so that the first successful visit yields 0, i.e. the
    # result is "positions visited - 1".
    hops = -1
    for _ in range(n):
        if position not in unvisited:
            break
        unvisited.discard(position)
        hops += 1
        position = (position + d) % n
    # NOTE(review): for n == 1 the start position 1 is not among the
    # positions {0}, so nothing is visited and the result is -1.  This is
    # preserved from the original script -- confirm it is intended.
    return hops


def main() -> None:
    """Read ``n d`` from one line of standard input and print the result."""
    n, d = map(int, input().split())
    print(steps_before_repeat(n, d))


if __name__ == "__main__":
    main()