def count_distinct_residues(n, k):
    """Count the distinct values of ``(i - (n - i)) % n`` seen before a repeat.

    Walks ``i = 0, 1, ..., k`` in order, computing ``(i - (n - i)) % n`` for
    each ``i`` (arithmetically the same as ``2 * i`` reduced modulo ``n``).
    The walk stops at the first value that has already been seen, or after
    ``i == k``, whichever comes first.

    Args:
        n: The modulus. ``n == 0`` raises ``ZeroDivisionError`` as soon as
            the first residue is computed (i.e. whenever ``k >= 0``).
        k: The last index to try. A negative ``k`` means no index is tried.

    Returns:
        The number of distinct residues collected before stopping.
    """
    seen = set()
    for i in range(k + 1):
        residue = (i - (n - i)) % n
        if residue in seen:
            # First repeat: stop here, later indices are not examined.
            break
        seen.add(residue)
    # Every residue is added exactly once, so the set size is the count.
    return len(seen)


def main():
    """Read ``n`` and ``k`` from one line of stdin and print the count."""
    n, k = map(int, input().split())
    print(count_distinct_residues(n, k))


if __name__ == "__main__":
    main()