"""Read two integers n and k from stdin and print n * (n**k - (n-1)**k) mod 1e9+7."""

import sys

MOD = 10**9 + 7


def solve(n: int, k: int) -> int:
    """Return ``n * (n**k - (n - 1)**k)`` reduced modulo ``MOD``.

    Args:
        n: Base of the two powers (the second power uses ``n - 1``).
        k: Exponent applied to both bases.

    Returns:
        The value of the expression modulo ``MOD``, in ``range(MOD)``.
    """
    # Three-argument pow keeps the intermediates below MOD. Their difference
    # may be negative, but Python's % with a positive modulus always yields a
    # non-negative result, so no explicit "+ MOD" correction is needed.
    return n * (pow(n, k, MOD) - pow(n - 1, k, MOD)) % MOD


def main() -> None:
    """Parse ``n`` and ``k`` from standard input and print the answer.

    Raises:
        ValueError: If stdin does not hold exactly two integer tokens.
    """
    # Whitespace-separated tokens, so "n k" may be on one line or split
    # across several lines.
    n, k = map(int, sys.stdin.read().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()