N,K=map(int,input().split()) mod=10**9+7 X=pow(N-1,K,mod) Y=pow(N,K,mod) print((Y-X)*N%mod)