"""Read two integers ``n k`` from stdin and print a modular power difference.

The printed value is ``((s + k)**n - s**n) mod 1_000_000_007`` where
``s = k * (k + 1) // 2``.
"""

MOD = 10**9 + 7


def solve(n: int, k: int) -> int:
    """Return ``((s + k)**n - s**n) % MOD`` with ``s = k * (k + 1) // 2``.

    Args:
        n: Exponent applied to both terms.
        k: Value from which ``s`` is derived and by which the first base
            exceeds the second.

    Returns:
        The difference reduced into the range ``[0, MOD)``.
    """
    # k * (k + 1) is always even, so the floor division is exact.
    s = k * (k + 1) // 2
    # Three-argument pow keeps the intermediates small even for huge n.
    # The final "% MOD" folds a negative difference back into [0, MOD).
    return (pow(s + k, n, MOD) - pow(s, n, MOD)) % MOD


def main() -> None:
    """Parse ``n`` and ``k`` from one whitespace-separated line and print the result."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()