MOD = 10**9 + 7


def count_surjections(n: int, m: int, mod: int = MOD) -> int:
    """Return sum_{i=0}^{m-1} (-1)^i * C(m, i) * (m - i)^n, reduced modulo *mod*.

    This is the inclusion-exclusion sum that counts the ways to map ``n``
    distinguishable items onto ``m`` distinguishable targets so that every
    target is hit at least once. The ``i = m`` term is not visited; it is
    ``0**n``, which vanishes for ``n >= 1``.

    Args:
        n: Number of items being mapped.
        m: Number of targets that must all be covered.
        mod: Modulus for the result. The binomial coefficients are updated
            with Fermat inverses, so this assumes *mod* is prime and
            ``m < mod``.

    Returns:
        The sum above as an integer in ``[0, mod)``; ``0`` whenever
        ``n < m``. Note that ``n == m == 0`` also yields ``0`` because the
        loop body never runs.
    """
    if n < m:
        # Fewer items than targets: some target is necessarily missed.
        return 0

    total = 0
    binom = 1  # C(m, 0)
    for i in range(m):
        term = binom * pow(m - i, n, mod) % mod
        # Alternating sign of inclusion-exclusion; Python's % keeps the
        # running total non-negative after a subtraction.
        if i & 1:
            total = (total - term) % mod
        else:
            total = (total + term) % mod
        # C(m, i + 1) = C(m, i) * (m - i) / (i + 1); the division is done
        # via the modular inverse (i + 1)^(mod - 2). Reduce after every
        # multiplication so the intermediate never grows past mod**2.
        binom = binom * (m - i) % mod
        binom = binom * pow(i + 1, mod - 2, mod) % mod
    return total


def main() -> None:
    """Read ``n m`` from one line of stdin and print the count modulo MOD."""
    n, m = map(int, input().split())
    print(count_surjections(n, m))


if __name__ == "__main__":
    main()