def solve(n: int, m: int) -> int:
    """Return the answer for modulus ``n`` and length ``m``.

    With F_1 = F_2 = 1 the Fibonacci numbers, this counts index pairs
    1 <= i <= j <= m - 2 such that F_i + ... + F_j is divisible by ``n``.
    It does so by pairing equal residues of the prefix sums
    S_0 = 0, S_1, ..., S_{m-2}.

    It then adds:
      * 1 if F_m is divisible by ``n``;
      * 1 if m >= 2 and F_{m-1} is divisible by ``n``;
      * 1 if m >= 3 and F_m is divisible by ``n``.

    The prefix sums modulo ``n`` are periodic with the Pisano period of
    ``n``, so only one period is materialised. That makes this work for
    very large ``m``.
    """
    one = 1 % n  # residue of 1; it is 0 when n == 1, which makes the period 1
    fib = [one, one, 2 % n]       # fib[k] == F_{k+1} mod n
    pref = [one, 2 % n, 4 % n]    # pref[k] == F_1 + ... + F_{k+1} mod n
    # Extend until the residues wrap around to (0, 1, 1), i.e. F_p == 0 and
    # F_{p+1} == F_{p+2} == 1 (mod n). The first such p is the Pisano period.
    while not (fib[-3] == 0 and fib[-2] == one and fib[-1] == one):
        fib.append((fib[-1] + fib[-2]) % n)
        pref.append((pref[-1] + fib[-1]) % n)
    # The last two entries already belong to the next period.
    del fib[-2:]
    del pref[-2:]
    period = len(fib)

    # Residue histogram of S_0 .. S_{m-2}: the empty prefix S_0 == 0, then
    # `full` complete periods plus the first `rem` entries of one more.
    size = max(0, m - 2)
    full, rem = divmod(size, period)
    count = [0] * n
    count[0] = 1
    for i, r in enumerate(pref):
        count[r] += full + (i < rem)

    # Every pair of equal prefix residues is one run with sum divisible by n.
    ans = sum(c * (c - 1) // 2 for c in count)

    if fib[(m - 1) % period] == 0:               # F_m == 0 (mod n)
        ans += 1
    if m >= 2 and fib[(m - 2) % period] == 0:    # F_{m-1} == 0 (mod n)
        ans += 1
    # NOTE(review): this repeats the F_m test from the first check. It is
    # kept exactly as in the original script; confirm the double count is
    # intended by the problem statement.
    if m >= 3 and fib[(m - 1) % period] == 0:
        ans += 1
    return ans


def main() -> None:
    """Read ``n m`` from one line of stdin and print ``solve(n, m)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()