import java.util.Scanner;

public class Main {
	/** Modulus applied to every intermediate value and to the printed answer. */
	private static final long MOD = 1_000_000_007L;

	/**
	 * Reads two integers {@code p} and {@code k} from standard input and prints
	 * {@link #solve(long, long)} for them.
	 */
	public static void main(String[] args) throws Exception {
		int p;
		int k;
		try (Scanner sc = new Scanner(System.in)) {
			p = sc.nextInt();
			k = sc.nextInt();
		}
		System.out.println(solve(p, k));
	}

	/**
	 * Evaluates a two-state linear recurrence modulo {@value #MOD}.
	 *
	 * <p>Starting from {@code dp0 = p + 1}, {@code dp1 = p - 1}, it applies {@code k - 1} steps of
	 * <pre>
	 *   dp0' = dp0 * (p + 1) + dp1 * 2
	 *   dp1' = dp0 * (p - 1) + dp1 * (2p - 2)
	 * </pre>
	 * and returns the final {@code dp0}. For {@code k <= 1} no step is applied.
	 *
	 * @param p the parameter of the recurrence; any value, reduced mod {@value #MOD} first
	 * @param k the number of "levels"; {@code k - 1} transitions are performed
	 * @return {@code dp0} after the transitions, always in {@code [0, MOD)}
	 */
	static long solve(long p, long k) {
		long pm = Math.floorMod(p, MOD);
		// Reduce every coefficient once, in long arithmetic, so that (a) no int
		// expression like 2*p-2 can overflow, (b) the k == 1 answer is already
		// reduced, and (c) negative values (p == 0) become proper residues.
		long plusOne = (pm + 1) % MOD;            // p + 1
		long minusOne = Math.floorMod(pm - 1, MOD); // p - 1
		long twiceMinusOne = (2 * minusOne) % MOD;  // 2p - 2

		long dp0 = plusOne;
		long dp1 = minusOne;
		for (long i = 1; i < k; i++) {
			// All operands are < MOD (~1e9), so each sum of two products is < 2^63.
			long next0 = (dp0 * plusOne + dp1 * 2) % MOD;
			long next1 = (dp0 * minusOne + dp1 * twiceMinusOne) % MOD;
			dp0 = next0;
			dp1 = next1;
		}
		return dp0;
	}
}