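m = 1000000007  # modulus 10**9 + 7
N, K = map(int, input().split())  # read N and K from one line of stdin
print(N - pow(N - 1, K, m))  # three-argument pow: (N - 1)**K mod m via fast modular exponentiation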