n,k=map(int,input().split()) mod=1000000007 print((pow(n,k,mod)-pow(n-1,k,mod))*n%mod)