from sys import stdin
def _count_ways(P, K, mod=10**9 + 7):
    """Return the state-0 count of the two-state recurrence after K steps, mod ``mod``.

    Counts start at (state0, state1) = (1, 0). Each step adds two transition
    groups, labelled "+" and "*", both applied to the previous step's counts:

        "+": state0 += s0 + s1          state1 += (P-1)*(s0 + s1)
        "*": state0 += P*s0 + s1        state1 += (P-1)*s1

    Only the previous step is ever read, so two scalars replace a (K+1)-row
    table. Memory is O(1) instead of O(K); time is O(K).

    NOTE(review): the transitions look like counting K choices of (op, x) with
    x in [0, P) while tracking whether a running value is ≡ 0 (state 0) or not
    (state 1) modulo P. Confirm this against the task statement.

    Args:
        P: integer parameter used in the transition weights.
        K: number of steps; K == 0 yields 1.
        mod: modulus applied to both counts after every step.

    Returns:
        The state-0 count after K steps, reduced modulo ``mod`` when K >= 1.
    """
    state0, state1 = 1, 0
    for _ in range(K):
        total = state0 + state1
        # "+" transitions
        plus0 = total
        plus1 = (P - 1) * total
        # "*" transitions
        mul0 = P * state0 + state1
        mul1 = (P - 1) * state1
        # Reduce every step so the integers stay small.
        state0 = (plus0 + mul0) % mod
        state1 = (plus1 + mul1) % mod
    return state0


def main():
    """Read "P K" from one line of stdin and print the state-0 count after K steps."""
    readline = stdin.readline
    # Input
    P, K = map(int, readline().split())
    print(_count_ways(P, K))


if __name__ == "__main__":
    main()