"""Count surjections (onto functions) from an n-element set onto an m-element set.

When run as a script, reads two integers ``n`` and ``m`` from one line of
standard input and prints the count modulo ``10**9 + 7``.  The count is
computed by inclusion-exclusion:

    sum_{i=0}^{m} (-1)^i * C(m, i) * (m - i)^n
"""

M = 10**9 + 7


def _factorial_tables(limit, mod=M):
    """Return ``(fact, inv_fact)`` lists for indices ``0..limit`` modulo ``mod``.

    ``mod`` must be a prime larger than ``limit`` so every factorial is
    invertible.
    """
    fact = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % mod
    inv_fact = [1] * (limit + 1)
    # Fermat's little theorem: a^(p-2) is the inverse of a modulo a prime p.
    inv_fact[limit] = pow(fact[limit], mod - 2, mod)
    # 1/(i-1)! = i * 1/i!, so fill the table from the top down.
    for i in range(limit, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % mod
    return fact, inv_fact


def count_surjections(n, m, mod=M):
    """Return the number of surjections from an n-set onto an m-set, modulo ``mod``.

    Args:
        n: size of the domain (non-negative).
        m: size of the codomain (non-negative).
        mod: prime modulus greater than ``m``; defaults to ``10**9 + 7``.

    Returns:
        The count reduced into ``range(mod)``; 0 when ``m > n``.
    """
    # No function can cover more points than it has inputs.  Handling this
    # up front also keeps the factorial tables within range (the original
    # sized them by n but indexed them by m, raising IndexError here).
    if m > n:
        return 0
    # Only C(m, i) is needed, so factorials up to m suffice.
    fact, inv_fact = _factorial_tables(m, mod)
    total = 0
    for i in range(m + 1):
        binom = fact[m] * inv_fact[i] % mod * inv_fact[m - i] % mod
        term = binom * pow(m - i, n, mod) % mod
        # Alternate signs for inclusion-exclusion.
        total += -term if i % 2 else term
    # Python's % yields a non-negative result even when total is negative.
    return total % mod


def main():
    """Read ``n m`` from stdin and print the surjection count mod ``M``."""
    n, m = map(int, input().split())
    print(count_surjections(n, m))


if __name__ == "__main__":
    main()