import java.util.*; import java.io.*; public class Main { static FastScanner sc = new FastScanner(System.in); static PrintWriter pw = new PrintWriter(System.out); static StringBuilder sb = new StringBuilder(); static long mod = (long) 1e9 + 7; public static void main(String[] args) throws Exception { solve(); pw.flush(); } public static void solve() { int N = sc.nextInt(); int K = sc.nextInt(); long rem = rep2(N,K,mod)-rep2(N-1,K,mod); if(rem < 0) rem += mod; rem *= N; pw.println(rem % mod); } private static long rep2(long b, long n, long mod){ if(n == 0) return 1; long bn = rep2(b,n/2,mod); if(n % 2 == 0){ return (bn*bn)%mod; }else{ return (bn*bn)%mod*b%mod; } } static class GeekInteger { public static void save_sort(int[] array) { shuffle(array); Arrays.sort(array); } public static int[] shuffle(int[] array) { int n = array.length; Random random = new Random(); for (int i = 0, j; i < n; i++) { j = i + random.nextInt(n - i); int randomElement = array[j]; array[j] = array[i]; array[i] = randomElement; } return array; } public static void save_sort(long[] array) { shuffle(array); Arrays.sort(array); } public static long[] shuffle(long[] array) { int n = array.length; Random random = new Random(); for (int i = 0, j; i < n; i++) { j = i + random.nextInt(n - i); long randomElement = array[j]; array[j] = array[i]; array[i] = randomElement; } return array; } } } class Combination { final static long mod = (long) 1e9 + 7; private static long[] fact, ifact; public Combination(int n) { fact = new long[n + 1]; ifact = new long[n + 1]; fact[0] = 1; long ln = n; for (long i = 1; i <= ln; ++i) { int ii = (int) i; fact[ii] = fact[ii - 1] % mod * i % mod; } ifact[n] = pow(fact[n], this.mod - 2); for (int i = n; i >= 1; --i) { int ii = (int) i; ifact[ii - 1] = ifact[ii] % mod * i % mod; } } public static long comb(int n, int k) { if (k < 0 || k > n) return 0; return fact[n] % mod * ifact[k] % mod * ifact[n - k] % mod; } public static long perm(int n, int k) { return comb(n, k) * fact[k] % mod; } public static long pow(long a, long b) { long ret = 1; long tmp = a; while (b > 0) { if ((b & 1) == 1) { ret = (ret * tmp) % mod; } tmp = (tmp * tmp) % mod; b = b >> 1; } return ret; } } class FastScanner { private BufferedReader reader = null; private StringTokenizer tokenizer = null; public FastScanner(InputStream in) { reader = new BufferedReader(new InputStreamReader(in)); tokenizer = null; } public FastScanner(FileReader in) { reader = new BufferedReader(in); tokenizer = null; } public String next() { if (tokenizer == null || !tokenizer.hasMoreTokens()) { try { tokenizer = new StringTokenizer(reader.readLine()); } catch (IOException e) { throw new RuntimeException(e); } } return tokenizer.nextToken(); } public String nextLine() { if (tokenizer == null || !tokenizer.hasMoreTokens()) { try { return reader.readLine(); } catch (IOException e) { throw new RuntimeException(e); } } return tokenizer.nextToken("\n"); } public long nextLong() { return Long.parseLong(next()); } public int nextInt() { return Integer.parseInt(next()); } public double nextDouble() { return Double.parseDouble(next()); } public String[] nextArray(int n) { String[] a = new String[n]; for (int i = 0; i < n; i++) a[i] = next(); return a; } public int[] nextIntArray(int n) { int[] a = new int[n]; for (int i = 0; i < n; i++) a[i] = nextInt(); return a; } public long[] nextLongArray(int n) { long[] a = new long[n]; for (int i = 0; i < n; i++) a[i] = nextLong(); return a; } }