import sys input=lambda: sys.stdin.readline().rstrip() n,k=map(int,input().split()) mod=10**9+7 k1=(k*(k+1)//2)%mod k2=(k*(k+3)//2)%mod ans1=1 ans2=1 for i in range(n): ans1=(ans1*k1)%mod ans2=(ans2*k2)%mod print((ans2-ans1)%mod)