n, k = map(int, input().split()) mod = 10 ** 9 + 7 print((pow(k * (k + 3) // 2, n, mod) - pow(k * (k + 1) // 2, n, mod)) % mod)