import java.util.Scanner;

/**
 * Reads {@code N} and {@code p} from standard input and prints
 * {@code sum_{i=0}^{N-1} F(i) * (F(0) + ... + F(i))  (mod 1_000_000_007)}, where
 * {@code F(0) = 0}, {@code F(1) = 1}, {@code F(i) = p * F(i-1) + F(i-2)}.
 */
public class Main {
  private static final int MOD = 1_000_000_007;

  public static void main(String[] args) {
    new Main();
  }

  /** Reads {@code N p} from stdin and prints the answer on one line. */
  public Main() {
    try (Scanner sc = new Scanner(System.in)) {
      int n = sc.nextInt();
      int p = sc.nextInt();
      System.out.println(solve(n, p));
    }
  }

  /**
   * Computes the sum over {@code i in [0, n)} of {@code F(i) * prefix(i)} modulo
   * 1_000_000_007, where {@code prefix(i) = F(0) + ... + F(i)}.
   *
   * @param n number of sequence terms; values {@code <= 1} yield 0 (empty or all-zero sum)
   * @param p recurrence multiplier; may be negative, it is reduced modulo 1_000_000_007
   * @return the sum, always in {@code [0, 1_000_000_007)}
   */
  static long solve(int n, long p) {
    // Only F(0) = 0 can contribute for n <= 1; also guards n == 0, which would
    // otherwise index past an empty sequence.
    if (n <= 1) {
      return 0;
    }
    // Java's % keeps the dividend's sign; normalize so every residue stays non-negative.
    long mult = ((p % MOD) + MOD) % MOD;

    long before = 0; // F(i-2)
    long last = 1;   // F(i-1)
    long prefix = 1; // F(0) + ... + F(i-1)
    long ans = 1;    // contributions of i = 0 (0 * 0) and i = 1 (1 * 1)
    for (int i = 2; i < n; i++) {
      // Both factors are < MOD, so each product is < 1e18 and fits in a long.
      long next = (mult * last + before) % MOD;
      prefix = (prefix + next) % MOD;
      ans = (ans + next * prefix) % MOD;
      before = last;
      last = next;
    }
    return ans;
  }
}