from sys import stdin

# Modulus applied to every DP value.
MOD = 10**9 + 7


def count_zero_state(p: int, k: int, mod: int = MOD) -> int:
    """Return the state-0 count after ``k`` steps of the two-state recurrence.

    Two counters are tracked, starting from ``zero = 1, nonzero = 0``.
    Each step adds the contributions of a "+" transition and a "*"
    transition:

        "+":  zero    <- zero + nonzero
              nonzero <- (p - 1) * zero + (p - 1) * nonzero
        "*":  zero    <- p * zero + nonzero
              nonzero <- (p - 1) * nonzero

    Both counters are reduced modulo ``mod`` after every step.

    NOTE(review): the coefficients look like counting sequences of ``k``
    operations ("+x" or "*x" with x in 0..p-1) that leave a value equal to
    0 modulo p -- confirm against the original problem statement.

    Args:
        p: The parameter P of the recurrence.
        k: Number of steps; must be non-negative.
        mod: Modulus for the result (defaults to 10**9 + 7).

    Returns:
        The state-0 counter after ``k`` steps, in ``range(mod)``.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        raise ValueError("k must be non-negative")

    # Only the previous step is ever read, so two rolling counters replace
    # a (k + 1) x 2 table.
    zero, nonzero = 1, 0
    for _ in range(k):
        # Contributions of the "+" transition.
        next_zero = zero + nonzero
        next_nonzero = (p - 1) * zero + (p - 1) * nonzero
        # Contributions of the "*" transition.
        next_zero += p * zero + nonzero
        next_nonzero += (p - 1) * nonzero
        zero, nonzero = next_zero % mod, next_nonzero % mod
    return zero


def main() -> None:
    """Read ``P K`` from one line of standard input and print the answer."""
    # Input: two whitespace-separated integers on a single line.
    p, k = map(int, stdin.readline().split())
    print(count_zero_state(p, k))


if __name__ == "__main__":
    main()