"""Read N and K from stdin and print (N**(K+1) - N*(N-1)**K) mod 10**9 + 7."""
import sys

MOD = 10 ** 9 + 7


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``(n**(k + 1) - n * (n - 1)**k) % mod``.

    Algebraically, ``n * (n - 1)**k`` counts the length-(k + 1) sequences
    over ``n`` symbols in which no two neighbours are equal, so the result
    would be the count of sequences with at least one equal adjacent pair.
    NOTE(review): that combinatorial reading is inferred from the formula
    only; confirm it matches the intended problem statement.

    Args:
        n: First integer read from input.
        k: Second integer read from input.
        mod: Modulus applied to the result (defaults to 10**9 + 7).

    Returns:
        An int in ``range(mod)``. Python's ``%`` with a positive modulus never
        yields a negative value, so the modular subtraction is safe.
    """
    # Three-argument pow keeps intermediates small (modular exponentiation).
    return (pow(n, k + 1, mod) - n * pow(n - 1, k, mod)) % mod


def main() -> None:
    """Read ``N K`` from the first line of stdin and print the answer."""
    # Read via sys.stdin directly instead of rebinding the builtin `input`.
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()