import sys

# Byte-oriented line reader from stdin. This deliberately rebinds the name
# `input` (shadowing the builtin), as the original script did, so any code
# relying on the module-level name keeps working.
input = sys.stdin.buffer.readline

MOD = 10**9 + 7


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``n * (n**k - (n - 1)**k)`` reduced modulo *mod*.

    Both powers use three-argument ``pow``, so each is reduced modulo *mod*
    before the subtraction and multiplication.

    Args:
        n: Base value.
        k: Exponent.
        mod: Modulus; defaults to ``10**9 + 7``.

    Returns:
        The reduced value, in ``range(mod)`` for a positive *mod*.
    """
    diff = pow(n, k, mod) - pow(n - 1, k, mod)
    # After the two independent reductions, diff may be negative. Python's `%`
    # takes the sign of the divisor, so the final result is still in [0, mod).
    return n * diff % mod


def main() -> None:
    """Read ``n k`` from the first stdin line and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == '__main__':
    main()