"""Read ``n`` and ``k`` from stdin and print ``n * (n**k - (n - 1)**k)`` modulo ``10**9 + 7``."""

MOD = 10**9 + 7


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``n * (n**k - (n - 1)**k)`` reduced modulo *mod*.

    Args:
        n: Base value; may be arbitrarily large (Python ints are unbounded).
        k: Exponent passed to the three-argument ``pow``.
        mod: Modulus; defaults to ``10**9 + 7``.

    Returns:
        The result in ``[0, mod)`` for a positive *mod*.
    """
    # The two modular powers are each in [0, mod), so their difference may be
    # negative. Python's ``%`` takes the sign of the divisor, so the final
    # reduction still lands in [0, mod) without an explicit "+ mod" fix-up.
    diff = pow(n, k, mod) - pow(n - 1, k, mod)
    # ``*`` and ``%`` share precedence and associate left: (n * diff) % mod.
    return n * diff % mod


def main() -> None:
    """Parse ``n`` and ``k`` (whitespace-separated, any lines) and print the answer."""
    import sys

    n, k = map(int, sys.stdin.read().split()[:2])
    print(solve(n, k))


if __name__ == "__main__":
    main()