import java.util.*;

/**
 * Reads two integers {@code N} and {@code P} from standard input and prints
 * {@code (F^F * e) mod MOD}, where {@code F = N! mod MOD}, {@code e} is the total
 * number of factors of {@code P} across {@code 1..N}, and {@code MOD = 1_000_000_007}.
 */
public class Main {
    static final int MOD = 1000000007;

    /**
     * Program entry point.
     *
     * @param args ignored
     * @throws IllegalArgumentException if {@code P < 2} (the factor count is undefined
     *     and the original loop would never terminate for {@code P == 1})
     */
    public static void main(String[] args) {
        int n;
        int p;
        try (Scanner sc = new Scanner(System.in)) {
            n = sc.nextInt();
            p = sc.nextInt();
        }
        if (p < 2) {
            throw new IllegalArgumentException("P must be at least 2, got " + p);
        }

        long exponentOfP = countFactor(n, p);
        long factMod = factorialMod(n);
        // NOTE(review): the exponent here is N! mod MOD, not N! itself. If the intended
        // value is (N!)^(N!) mod MOD, the exponent should be reduced mod (MOD - 1)
        // instead (Fermat). Confirm against the problem statement.
        long ans = powMod(factMod, factMod) * (exponentOfP % MOD) % MOD;
        System.out.println(ans);
    }

    /**
     * Returns the sum over i in [1, n] of the largest k such that p^k divides i.
     * Computed as floor(n/p) + floor(n/p^2) + ... (Legendre-style), which equals that
     * sum for every p >= 2. Returns 0 when n < p, including n <= 0.
     */
    private static long countFactor(int n, int p) {
        long count = 0;
        // floor(floor(n/p^k)/p) == floor(n/p^(k+1)), so repeated division avoids overflow.
        for (long q = n / p; q > 0; q /= p) {
            count += q;
        }
        return count;
    }

    /** Returns n! mod MOD; returns 1 for n <= 1 (empty product). */
    private static long factorialMod(int n) {
        if (n >= MOD) {
            // n! contains MOD itself as a factor, so the residue is 0.
            return 0;
        }
        long result = 1;
        for (int i = 2; i <= n; i++) {
            result = result * i % MOD;
        }
        return result;
    }

    /**
     * Returns base^exp mod MOD by square-and-multiply over the bits of exp (LSB first).
     * Returns 1 when exp == 0 (so 0^0 is 1). Expects base >= 0 and exp >= 0.
     */
    private static long powMod(long base, long exp) {
        long result = 1;
        long b = base % MOD;
        for (long e = exp; e > 0; e >>= 1) {
            if ((e & 1) == 1) {
                result = result * b % MOD;
            }
            b = b * b % MOD;
        }
        return result;
    }
}