import sys input=lambda: sys.stdin.readline().rstrip() p,k=map(int,input().split()) a,b=1,0 mod=10**9+7 for i in range(k): a,b=a*(p+1)+b*2,a*(p-1)+b*2*(p-1) if a>mod: a%=mod if b>mod: b%=mod print(a)