MOD = 10**9 + 7


def solve(n, p, mod=MOD):
    """Return sum_{j=1}^{n-1} a_j * S_j (mod ``mod``) for the sequence below.

    The sequence is a_0 = 0, a_1 = 1, a_k = p * a_{k-1} + a_{k-2}.
    S_j = a_1 + ... + a_j is its prefix sum.

    Args:
        n: Problem size N from the input.
        p: Multiplier in the recurrence. It is reduced modulo ``mod``
            implicitly by the per-step ``%``.
        mod: Modulus applied while iterating (default 1e9+7).

    Returns:
        The answer as an int. For ``n <= 2`` this is ``n - 1``, returned
        unreduced to match the original script's output.
    """
    if n <= 2:
        # NOTE(review): n - 1 is the original special case; it yields 0 for
        # n == 1 and 1 for n == 2. Non-positive n gives a negative result,
        # which is kept for compatibility. Confirm the problem constraints.
        return n - 1

    prev, cur = 0, 1        # a_{i-1}, a_i, starting at a_0, a_1
    prefix, ans = 1, 1      # S_1 = a_1 and the j == 1 term a_1 * S_1
    for _ in range(2, n):
        prev, cur = cur, (cur * p + prev) % mod
        prefix = (prefix + cur) % mod
        ans = (ans + cur * prefix) % mod
    return ans


def main():
    """Read ``N p`` from stdin and print ``solve(N, p)``."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()