m = 1000000007

N, K = map(int, input().split())

print((pow(N, K + 1, m) - N * pow(N - 1, K, m)) % m)