# Default modulus used for the printed answer.
MOD = 10**9 + 7


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``x**n - y**n`` reduced modulo *mod*.

    Here ``x = k*(k+3)//2`` and ``y = k*(k+1)//2``.

    Args:
        n: Exponent applied to both terms.
        k: Parameter from which the two bases are derived.
        mod: Modulus for the result; defaults to ``10**9 + 7``.

    Returns:
        The difference of the two modular powers, normalized into ``[0, mod)``
        for a positive *mod*.
    """
    # For any integer k, one of k and k+3 is even, and one of k and k+1 is
    # even, so both floor divisions by 2 are exact.
    x = k * (k + 3) // 2
    y = k * (k + 1) // 2
    # The raw difference of the two modular powers can be negative.
    # Python's % with a positive modulus maps it back into [0, mod).
    return (pow(x, n, mod) - pow(y, n, mod)) % mod


def main() -> None:
    """Read ``n k`` from one line of stdin and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()