def _pisano_cycle(N):
    """Return one full period of the Fibonacci sequence reduced modulo N.

    The result starts with ``[0, 1 % N, ...]`` and stops just before that
    opening pair reappears, so its length is the Pisano period of N.
    """
    seq = [0, 1 % N]
    while True:
        seq.append((seq[-1] + seq[-2]) % N)
        if seq[-2] == seq[0] and seq[-1] == seq[1]:
            break
    # Drop the trailing pair, which is the start of the next period.
    del seq[-2:]
    return seq


def solve(N, M):
    """Return the Fibonacci-residue count for modulus N and index M.

    With F_0 = 0, F_1 = 1 and prefix sums P_k = F_0 + ... + F_k, the result
    is:

    * the number of index pairs 0 <= i < j <= M - 2 with P_i == P_j (mod N),
    * plus 2 if F_{M+1} == 0 (mod N),
    * plus 1 if F_M == 0 (mod N).

    N == 1 and M <= 2 are special-cased exactly as in the original script.
    NOTE(review): the weights 2 and 1 come from the problem statement,
    which is not visible here.

    Runs in O(N) time regardless of M, because the Pisano period is at
    most 6N.
    """
    if N == 1:
        ans = M
        if M >= 3:
            ans = (M - 2) * (M - 1) // 2 + 3
        return ans
    if M <= 2:
        return 0

    # Always take one *full* period, independent of M. The previous
    # version stopped at index M, so for M <= period the final F_M /
    # F_{M+1} checks read the wrong list entries.
    cycle = _pisano_cycle(N)
    n = len(cycle)

    # prefix[i] = P_i mod N over one period. The prefix sums are periodic
    # with period n, because a full period sums to F_{n+1} - 1 == 0 (mod N).
    prefix = []
    running = 0
    for value in cycle:
        running = (running + value) % N
        prefix.append(running)

    per_cycle = [0] * N
    for residue in prefix:
        per_cycle[residue] += 1

    # P_0 .. P_{M-2} is `full` complete periods followed by `rem` entries.
    full, rem = divmod(M - 1, n)
    ans = 0
    seen = [0] * N
    for residue, r in enumerate(per_cycle):
        ans += r * r * (full * (full - 1) // 2)  # pairs in two different periods
        ans += full * (r * (r - 1) // 2)         # pairs within one period
        seen[residue] = r * full
    for residue in prefix[:rem]:
        ans += seen[residue]  # pair with every earlier equal residue
        seen[residue] += 1

    # F_{M+1} mod N; the original script labels this fib_M.
    if cycle[(M + 1) % n] == 0:
        ans += 2
    # F_M mod N; the original script labels this fib_(M - 1).
    if cycle[M % n] == 0:
        ans += 1
    return ans


if __name__ == "__main__":
    N, M = map(int, input().split())
    print(solve(N, M))