n,k=map(int,input().split()) mod=10**9+7 # 全部同時に計算する。 kn=pow(k,n,mod) ksum=k*(k+1)//2 ksum%=mod invk=pow(k,mod-2,mod) dp=kn dp0=kn for i in range(n): # s[i]をかける # s[i]をかけない dp=dp*(1+1*ksum*invk)%mod dp0=dp0*ksum*invk%mod print((dp-dp0)%mod)