import java.util.*;

/**
 * Reads two integers {@code n} and {@code p} from standard input and prints a weighted sum over a
 * Fibonacci-like sequence, reduced modulo {@link #MOD}.
 *
 * <p>The sequence is x(1) = 0, x(2) = 1, x(i) = p * x(i-1) + x(i-2) for i >= 3. With
 * S(i) = x(1) + ... + x(i), the printed value is the sum of S(i) * x(i) for i = 2..n, mod MOD.
 */
public class Main {
    static final int MOD = 1000000007;

    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            int p = sc.nextInt();
            System.out.println(solve(n, p));
        }
    }

    /**
     * Computes the sum of S(i) * x(i) for i = 2..n modulo {@link #MOD}.
     *
     * @param n last index of the sequence to include; any value {@code <= 1} yields an empty sum
     * @param p multiplier in the recurrence x(i) = p * x(i-1) + x(i-2); may be negative
     * @return the sum reduced into {@code [0, MOD)}
     */
    static long solve(int n, int p) {
        // Empty range i = 2..n. (The original tested `n == 1` twice, leaving a dead branch.)
        if (n <= 1) {
            return 0;
        }
        // Normalize p into [0, MOD): Java's % keeps the dividend's sign, so a negative p would
        // otherwise leak negative residues into x, sum and total.
        long mult = Math.floorMod(p, MOD);
        long prepre = 0; // x(i-2), starts as x(1)
        long pre = 1;    // x(i-1), starts as x(2)
        long sum = 1;    // S(i-1), starts as S(2) = 0 + 1
        long total = 1;  // running answer, starts as S(2) * x(2) = 1, which is the n == 2 result
        for (int i = 3; i <= n; i++) {
            // Every operand is in [0, MOD), so each product stays below ~1e18 and fits in a long.
            long x = (pre * mult + prepre) % MOD;
            sum = (sum + x) % MOD;
            total = (total + sum * x) % MOD;
            prepre = pre;
            pre = x;
        }
        return total;
    }
}