import sys

MOD = 10**9 + 7


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``n * (n**k - (n - 1)**k)`` reduced modulo *mod*.

    Both powers are computed with three-argument ``pow``, so large *k*
    stays cheap and intermediate values stay bounded by *mod*.

    Args:
        n: Base of the first power; also the final multiplier.
        k: Exponent applied to both ``n`` and ``n - 1`` (assumed >= 0).
        mod: Positive modulus; defaults to ``10**9 + 7``.

    Returns:
        The result as an integer in ``[0, mod)``.
    """
    # The difference of the two reduced powers can be negative. Python's
    # ``%`` with a positive right operand always yields a value in
    # [0, mod), so the final reduction also normalizes the sign.
    return (pow(n, k, mod) - pow(n - 1, k, mod)) * n % mod


def main() -> None:
    """Read ``N K`` from one line of stdin and print ``solve(N, K)``."""
    readline = sys.stdin.readline
    n, k = map(int, readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()