import java.util.*;

/**
 * Reads {@code n} and {@code p} from standard input and prints the sum, modulo 1e9+7, of
 * {@code f(i) * f(j)} over all index pairs {@code 1 <= i <= j <= n - 1}, where
 * {@code f(0) = 0}, {@code f(1) = 1} and {@code f(i) = p * f(i - 1) + f(i - 2)}.
 */
public class Main {
  static final int MOD = 1000000007;

  public static void main(String[] args) {
    try (Scanner sc = new Scanner(System.in)) {
      int n = sc.nextInt();
      int p = sc.nextInt();
      System.out.println(solve(n, p));
    }
  }

  /**
   * Computes the sum of {@code f(i) * f(j)} over {@code 1 <= i <= j <= n - 1}, modulo {@link #MOD}.
   *
   * <p>Uses the identity {@code sum_{i<=j} f(i) f(j) = sum_j f(j) * (f(1) + ... + f(j))}, so a
   * single pass with a running prefix sum suffices and no arrays are allocated.
   *
   * @param n number of sequence terms considered (indices {@code 0 .. n-1}); values below 2 yield 0
   * @param p multiplier of the recurrence; may be negative, it is reduced into {@code [0, MOD)}
   * @return the pair-product sum in {@code [0, MOD)}
   */
  static long solve(int n, int p) {
    // Fewer than two terms means only f(0) = 0 (or nothing) is present, so no pair contributes.
    // The original threw ArrayIndexOutOfBoundsException here for n == 1 and n == 0.
    if (n < 2) {
      return 0;
    }
    // floorMod keeps every intermediate value non-negative even when p < 0; for p >= 0 it
    // gives exactly the same residues as multiplying by p directly.
    long mult = Math.floorMod(p, MOD);

    long prev = 0;   // f(i - 2)
    long cur = 1;    // f(i - 1); starts at f(1)
    long prefix = 1; // f(1) + ... + f(i - 1)
    long ans = 1;    // contribution of f(1) * f(1)
    for (int i = 2; i < n; i++) {
      // All operands are < MOD (~1e9), so each product stays well below Long.MAX_VALUE.
      long next = (mult * cur + prev) % MOD;
      prev = cur;
      cur = next;
      prefix = (prefix + cur) % MOD;
      ans = (ans + cur * prefix) % MOD;
    }
    return ans;
  }
}