n, k = map(int, input().split()) mod = 10**9+7 ans = pow(n, k, mod)-pow(n-1, k, mod) ans *= n ans %= mod print(ans)