import sys

MOD = 10**9 + 7


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``n * (n**k - (n - 1)**k)`` reduced modulo ``mod``.

    This equals ``(1 - ((n - 1) / n)**k) * n * n**k``, the expression the
    original script evaluated through a Fermat modular inverse of ``n``.
    Presumably this is the total number of distinct values summed over all
    length-``k`` sequences drawn from ``n`` symbols (TODO: confirm against
    the problem statement).

    Expanding the product removes the division entirely. As a result:
      * ``mod`` need not be prime (a Fermat inverse requires a prime modulus);
      * ``n`` divisible by ``mod`` needs no special handling.

    Args:
        n: Number of symbols; assumed non-negative.
        k: Exponent / sequence length; assumed non-negative.
        mod: Modulus for the result. Defaults to 10**9 + 7.

    Returns:
        An int in ``[0, mod)``.
    """
    # Python's % always returns a non-negative result for a positive
    # modulus, so the difference of the two powers may be negative here.
    return n * (pow(n, k, mod) - pow(n - 1, k, mod)) % mod


def main() -> None:
    """Read ``n k`` from the first line of stdin and print ``solve(n, k)``."""
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()