n,m=map(int,input().split()) if n==1: if m==1: print(2) elif m==2: print(2) else: print((m-2)*(m-3)//2+m+1) exit() if m==1: print(0) exit() if m==2: print(0) exit() if m==3: if n==2: print(2) else: print(0) exit() loop=[0,1] while True: loop.append((loop[-1]+loop[-2]+1)%n) if loop[-1]==loop[-2]==0: break loop=loop[1:] cnt=[0]*n cnt[0]=1 l=len(loop) p=(m-2)//l for i in loop: cnt[i]+=p for i in range((m-2)%l): cnt[loop[i]]+=1 ans=0 last=(loop[(m-2)%l]-loop[(m-3)%l])%n if last==0: ans+=1 last=(loop[(m-2)%l]-loop[(m-4)%l])%n if last==0: ans+=1 last=(loop[(m-1)%l]-loop[(m-2)%l])%n if last==0: ans+=1 for i in cnt: ans+=i*(i-1)//2 print(ans)