def solve(n, m):
    """Return the count the script prints for modulus ``n`` and length ``m``.

    Three regimes are handled separately:

    * ``n == 1``: a closed-form expression in ``m``.
    * ``m <= 10``: brute force over every index pair ``left <= right`` of the
      first ``m`` Fibonacci numbers (1, 1, 2, 3, ...), counting the ranges
      whose sum is divisible by ``n`` and does not exceed the ``m``-th one.
    * otherwise: a counting argument over one full period of the Fibonacci
      sequence modulo ``n``.

    NOTE(review): the original problem statement is not visible here; the
    code presumably relies on ``n >= 1`` and ``m >= 1`` being guaranteed by
    the input constraints -- confirm before reusing elsewhere.
    """
    if n == 1:
        # Must be handled before the period search below: modulo 1 the state
        # pair (a, b) can never come back to (0, 1), so that loop would not
        # terminate.
        ans = m * (m + 1) // 2
        ans -= m - 1
        if m >= 3:
            ans -= m - 3
        return ans

    if m <= 10:
        fib = [1, 1]
        while len(fib) < m:
            fib.append(fib[-1] + fib[-2])
        limit = fib[m - 1]

        ans = 0
        for left in range(m):
            tot = 0
            for right in range(left, m):
                tot += fib[right]
                if tot % n == 0 and tot <= limit:
                    ans += 1
        return ans

    # Walk the Fibonacci sequence modulo n, starting from (0, 1), until that
    # state pair recurs, i.e. over exactly one period. Record each residue
    # and the running prefix sum modulo n.
    a, b = 0, 1
    prefix = 0
    residues = []
    prefixes = []
    prefix_count = [0] * n
    while True:
        prefix = (prefix + a) % n
        prefixes.append(prefix)
        residues.append(a)
        prefix_count[prefix] += 1
        a, b = b, (a + b) % n
        if a == 0 and b == 1:
            break

    period = len(prefixes)

    # Tally how often each prefix residue occurs among the first m - 1
    # prefix sums: whole periods first, then the leftover partial period.
    # (The residues over one period sum to 0 mod n, so the prefix sums
    # repeat with the same period.)
    full_periods, remainder = divmod(m - 1, period)
    counts = [c * full_periods for c in prefix_count]
    for p in prefixes[:remainder]:
        counts[p] += 1

    # Two equal prefix residues delimit a range whose sum is 0 mod n.
    ans = sum(c * (c - 1) // 2 for c in counts)

    # Up to three extra hits decided by the residues around position m.
    # Negative indices deliberately wrap to the end of the period.
    pos = m % period
    before_prev = residues[pos - 2]
    prev = residues[pos - 1]
    cur = residues[pos]
    if (before_prev + prev) % n == 0:
        ans += 1
    if prev % n == 0:
        ans += 1
    if cur % n == 0:
        ans += 1
    return ans


def main():
    """Read ``n m`` from one line of standard input and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()