N,K=map(int,input().split()) mod=10**9+7 print((1-pow(N-1,K,mod)*pow(pow(N,mod-2,mod),K,mod))*N*pow(N,K,mod)%mod)