"""Count surjections from an n-element set onto an m-element set, modulo MOD.

Reads two integers ``n m`` from a single line of stdin and prints the count.
"""
import math

# NOTE(review): the original script referenced an undefined modulus ``M``;
# 10**9 + 7 is assumed here -- TODO confirm against the problem statement.
MOD = 10**9 + 7


def count_surjections(n, m, mod=MOD):
    """Return the number of onto maps from an n-set to an m-set, modulo ``mod``.

    Uses inclusion-exclusion over the set of target elements that are missed:

        sum_{i=0}^{m} (-1)^i * C(m, i) * (m - i)^n

    Args:
        n: size of the domain (non-negative).
        m: size of the codomain (non-negative).
        mod: modulus applied to the result.

    Returns:
        The count reduced into ``range(mod)``. Note ``count_surjections(0, 0)``
        is 1 because ``pow(0, 0, mod) == 1`` (the empty map is onto).
    """
    total = 0
    for i in range(m + 1):
        # math.comb(m, i) is exact; reducing it keeps intermediate ints small.
        term = (math.comb(m, i) % mod) * pow(m - i, n, mod)
        # Odd i subtracts, even i adds (the (-1)^i sign of inclusion-exclusion).
        total += -term if i % 2 else term
    # Python's % always yields a non-negative result for a positive modulus.
    return total % mod


def main():
    """Read ``n m`` from stdin and print the surjection count modulo MOD."""
    n, m = map(int, input().split())
    print(count_surjections(n, m))


if __name__ == "__main__":
    main()