mod = 10**9 + 7 n,k = map(int,input().split()) print((n * (pow(n,k,mod) - pow(n-1,k,mod)))%mod)