import java.util.*;

public class Main{
	/** Modulus applied to every intermediate product. */
	static final int MOD = 1000000007;

	/**
	 * Reads two integers {@code N} and {@code P} from standard input and prints
	 * {@code (N! mod MOD)^(N! mod MOD) * c mod MOD}, where {@code c} is the number of times
	 * {@code P} divides each of {@code 1..N}, summed. Prints {@code 0} when {@code N < P}.
	 *
	 * @param args unused
	 * @throws IllegalArgumentException if {@code N >= P} and {@code P} is -1, 0 or 1; the
	 *         repeated-division count is undefined for those values (it previously looped
	 *         forever for +/-1 and divided by zero for 0)
	 */
	public static void main(String[] args){
		Scanner sc = new Scanner(System.in);
		int N = sc.nextInt();
		int P = sc.nextInt();

		if(N < P){
			System.out.println(0);
			return;
		}
		// Dividing by +/-1 never reduces n, and % 0 throws; reject these explicitly.
		if(P >= -1 && P <= 1){
			throw new IllegalArgumentException("P must not be -1, 0 or 1: " + P);
		}

		long numP = countFactor(N, P);
		long nM = factorialMod(N);

		// NOTE(review): the exponent used here is N! mod MOD. If the intended quantity is
		// (N!)^(N!) mod MOD, Fermat's little theorem requires reducing the exponent modulo
		// (MOD - 1) instead. The original behaviour is kept; confirm against the problem statement.
		long ans = modPow(nM, nM);
		ans = ans * (numP % MOD) % MOD;

		System.out.println(ans);
	}

	/**
	 * Returns the total number of times {@code p} divides each integer in {@code [1, n]},
	 * i.e. the sum over i of the largest k with p^k | i. Requires {@code |p| >= 2}.
	 */
	private static long countFactor(int n, int p){
		long count = 0;
		for(int i = 1; i <= n; i++){
			for(int m = i; m % p == 0; m /= p){
				count++;
			}
		}
		return count;
	}

	/** Returns {@code n! mod MOD}; yields 1 for {@code n <= 1}. */
	private static long factorialMod(int n){
		long f = 1;
		for(int i = 2; i <= n; i++){
			// f < MOD and i < 2^31, so the product fits in a long before reduction.
			f = f * i % MOD;
		}
		return f;
	}

	/**
	 * Returns {@code base^exp mod MOD} by square-and-multiply over the bits of {@code exp}.
	 * Expects {@code base >= 0} and {@code exp >= 0}; {@code 0^0} evaluates to 1.
	 */
	private static long modPow(long base, long exp){
		long result = 1;
		long b = base % MOD;
		for(long e = exp; e > 0; e >>= 1){
			if((e & 1) == 1){
				result = result * b % MOD;
			}
			b = b * b % MOD;
		}
		return result;
	}
}