import java.util.Scanner;

/**
 * Reads two integers {@code p} and {@code k} from standard input and prints the first component
 * of a two-term linear recurrence after {@code k} steps, reduced modulo {@value #MOD}.
 */
public class Main {

  /** Modulus applied to every intermediate and final result. */
  static final long MOD = 1_000_000_007L;

  public static void main(String[] args) {
    // try-with-resources closes the Scanner even if parsing fails.
    try (Scanner sc = new Scanner(System.in)) {
      long p = sc.nextLong();
      long k = sc.nextLong();
      System.out.println(solve(p, k));
    }
  }

  /**
   * Evaluates the recurrence
   *
   * <pre>
   *   s0(1)   = p + 1
   *   s1(1)   = p - 1
   *   s0(i+1) = (p + 1) * s0(i) + 2 * s1(i)
   *   s1(i+1) = (p - 1) * s0(i) + (2p - 2) * s1(i)
   * </pre>
   *
   * with all arithmetic done modulo {@link #MOD}.
   *
   * @param p recurrence parameter; any {@code long} value is accepted
   * @param k number of steps; values {@code <= 1} yield {@code s0(1)}
   * @return {@code s0(k) mod MOD}, always in {@code [0, MOD)}
   */
  static long solve(long p, long k) {
    // Reduce p first so that p + 1 / p - 1 / 2p - 2 can never overflow, and use floorMod so a
    // negative p (or p == 0, where p - 1 < 0) still yields non-negative residues.
    long r = Math.floorMod(p, MOD);
    long a = (r + 1) % MOD;       // (p + 1)  mod MOD
    long b = (r - 1 + MOD) % MOD; // (p - 1)  mod MOD
    long c = (2 * b) % MOD;       // (2p - 2) mod MOD, since 2p - 2 = 2(p - 1)

    long dp0 = a;
    long dp1 = b;
    for (long i = 1; i < k; i++) {
      // Every operand is < MOD, so each product is < MOD^2 (about 1e18) and the sum of two
      // such products stays well below Long.MAX_VALUE.
      long next0 = (dp0 * a + dp1 * 2) % MOD;
      long next1 = (dp0 * b + dp1 * c) % MOD;
      dp0 = next0;
      dp1 = next1;
    }
    return dp0;
  }
}