MOD = 10**9 + 7
# Fixed power applied to K; its meaning comes from the original problem
# statement, which is not visible here.
EXPONENT = 5 * 10**8 + 1


def solve(x, k, mod=MOD, n=EXPONENT):
    """Return ``x ** (k ** n) % mod`` without materializing ``k ** n``.

    The exponent ``k ** n`` is far too large to compute directly, so it is
    reduced modulo ``mod - 1`` (Fermat's little theorem). That reduction is
    only valid when ``mod`` is prime and ``x`` is coprime to ``mod``. The
    case ``x % mod == 0`` is therefore handled separately.

    Args:
        x: Base. Assumed to be a non-negative integer.
        k: Base of the exponent. Assumed non-negative (a negative ``k`` with
            odd ``n`` would make the true exponent negative).
        mod: Modulus. Must be prime for the exponent reduction to be valid.
        n: Power to which ``k`` is raised.

    Returns:
        ``x ** (k ** n)`` reduced modulo ``mod``.
    """
    if x % mod == 0:
        # Fermat's reduction does not apply here. 0 ** e is 0 for every
        # e >= 1, but reducing e modulo (mod - 1) can turn it into 0 and
        # wrongly produce 1. The true exponent k ** n is 0 only when
        # k == 0 and n > 0 (Python defines 0 ** 0 == 1).
        return 1 if (k == 0 and n > 0) else 0
    return pow(x, pow(k, n, mod - 1), mod)


def main():
    """Read ``X K`` from one line of stdin and print the answer."""
    x, k = map(int, input().split())
    print(solve(x, k))


if __name__ == "__main__":
    main()