P = 10 ** 9 + 7


def count_surjections(n, m, mod=P):
    """Return the number of surjective maps from an n-set onto an m-set, modulo ``mod``.

    Evaluated by inclusion-exclusion over the set of target elements that are
    missed::

        sum_{i=0}^{m} (-1)^i * C(m, i) * (m - i)^n

    ``mod`` must be a prime greater than ``m``: the binomial coefficients are
    built from inverse factorials obtained via Fermat's little theorem.

    When ``m > n`` no surjection exists, so 0 is returned immediately. This also
    avoids allocating tables indexed by ``m`` beyond what ``n`` allows.
    """
    if m > n:
        return 0

    # Only C(m, i) is needed, so factorials 0..m (m + 1 entries) suffice.
    fact = [1] * (m + 1)
    for i in range(2, m + 1):
        fact[i] = fact[i - 1] * i % mod

    fact_inv = [1] * (m + 1)
    fact_inv[m] = pow(fact[m], mod - 2, mod)
    # Walk downward using 1/(i-1)! = i * 1/i!, so only one modular exponentiation is needed.
    for i in range(m, 1, -1):
        fact_inv[i - 1] = fact_inv[i] * i % mod

    total = 0
    for i in range(m + 1):
        comb = fact[m] * fact_inv[i] % mod * fact_inv[m - i] % mod
        # pow(0, 0, mod) == 1, so the i == m term is correct even when n == 0.
        term = comb * pow(m - i, n, mod) % mod
        # Odd i subtracts, even i adds; Python's % keeps total in [0, mod).
        total = (total - term if i % 2 else total + term) % mod
    return total


def main():
    """Read ``N M`` from one line of stdin and print the surjection count mod 1e9+7."""
    n, m = map(int, input().split())
    print(count_surjections(n, m))


if __name__ == "__main__":
    main()