N, K = map(int, input().split()) mod = 10 ** 9 + 7 print(N * (pow(N, K, mod) - pow(N - 1, K, mod)) % mod)