mod=10**9+7 N,K=map(int,input().split()) ans=N*(pow(N,K,mod)-pow(N-1,K,mod))%mod print(ans%mod)