M,K=map(int,input().split()) mod=10**9+7 print(((2*pow(2*M,K,mod)+pow(M-1,K+1,mod))*pow(M+1,mod-2,mod))%mod)