"""Read two integers N and K from stdin and print N * (N**K - (N-1)**K) mod P."""

# Prime modulus under which the result is reported.
P = 10**9 + 7


def solve(n: int, k: int) -> int:
    """Return ``n * (n**k - (n - 1)**k)`` reduced modulo ``P``.

    Both powers are computed with three-argument ``pow`` so intermediate
    values stay below ``P`` regardless of how large ``k`` is.
    """
    # The difference of the two residues may be negative; Python's ``%``
    # with a positive modulus still yields a result in ``[0, P)``.
    difference = pow(n, k, P) - pow(n - 1, k, P)
    return (n * difference) % P


def main() -> None:
    """Parse ``N K`` from one line of stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()