mod = 10 ** 9 + 7 def comb(n, r): if n < r:return 0 if n < 0 or r < 0:return 0 return fa[n] * fi[r] % mod * fi[n - r] % mod n, m = map(int, input().split()) fa = [1] * (m + 1) fi = [1] * (m + 1) for i in range(1, m + 1): fa[i] = fa[i - 1] * i % mod fi[i] = pow(fa[i], mod - 2, mod) ans = 0 for i in range(m + 1): cnt = comb(m, i) * pow(i, n, mod) if (m - i) & 1: ans -= cnt else: ans += cnt print(ans % mod)