MOD = 10 ** 9 + 7
n, k = map(int, input().split())
print(n * (pow(n, k, MOD) - pow(n - 1, k, MOD)) % MOD)