import java.util.Scanner;

public class Main {
    /** Modulus applied to every intermediate value. */
    private static final int MOD = 1_000_000_007;

    public static void main(String[] args) {
        new Main();
    }

    /**
     * Reads {@code N} and {@code p} (two integers) from standard input and prints
     * {@link #solve(int, long)} for them.
     */
    public Main() {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            int p = sc.nextInt();
            System.out.println(solve(n, p));
        }
    }

    /**
     * Computes {@code sum_{i=0}^{n-1} f(i) * S(i)} modulo 1_000_000_007, where
     * {@code f(0) = 0}, {@code f(1) = 1}, {@code f(i) = p * f(i-1) + f(i-2)} and
     * {@code S(i) = f(0) + ... + f(i)}.
     *
     * @param n number of terms; any {@code n <= 1} yields 0 (the only possible
     *          non-zero contribution starts at index 1)
     * @param p recurrence coefficient; may be negative or exceed the modulus, it is
     *          reduced into {@code [0, MOD)} first
     * @return the sum, in {@code [0, MOD)}
     */
    public static long solve(int n, long p) {
        if (n <= 1) {
            // Empty range, or only the f(0) * S(0) = 0 term.
            return 0;
        }
        // floorMod keeps the coefficient non-negative so every residue stays in [0, MOD).
        long coef = Math.floorMod(p, (long) MOD);
        long prev = 0;   // f(i-2)
        long cur = 1;    // f(i-1)
        long prefix = 1; // S(i-1)
        long ans = 1;    // f(0)*S(0) + f(1)*S(1) = 0 + 1*1
        for (int i = 2; i < n; i++) {
            // All operands are < MOD (~1e9), so each product is < ~1e18 and fits in a long.
            long next = (coef * cur + prev) % MOD;
            prefix = (prefix + next) % MOD;
            ans = (ans + next * prefix) % MOD;
            prev = cur;
            cur = next;
        }
        return ans;
    }
}