import sys
input=lambda: sys.stdin.readline().rstrip()
n,k=map(int,input().split())
mod=10**9+7
k1=(k*(k+1)//2)%mod
k2=(k*(k+3)//2)%mod
ans1=1
ans2=1
for i in range(n):
	ans1=(ans1*k1)%mod
	ans2=(ans2*k2)%mod
print((ans2-ans1)%mod)