# Read two whitespace-separated integers from a single line of standard input.
n, k = (int(token) for token in input().split())

# Modulus applied to the printed result (10**9 + 7).
MOD = 1_000_000_007

# Closed form k*(k+1)/2, which equals 1 + 2 + ... + k when k is non-negative.
triangular = (k * k + k) // 2

# Three-argument pow reduces modulo MOD during exponentiation, so the
# intermediate values never grow beyond the modulus.
upper = pow(triangular + k, n, MOD)
lower = pow(triangular, n, MOD)

# The difference of two residues can be negative; the final % MOD maps it
# back into the range [0, MOD).
print((upper - lower) % MOD)