import java.util.Arrays;
import java.util.LinkedList;
import java.util.Scanner;

public class Main {

    /**
     * Returns the greatest common divisor of {@code a} and {@code b} using Euclid's algorithm.
     * Intended for non-negative arguments; with negative inputs the sign of the result follows
     * Java's {@code %} semantics. Not referenced by {@link #main}.
     */
    public static long gcd(long a, long b) {
        return b == 0 ? a : gcd(b, a % b);
    }

    /**
     * Returns the least common multiple of {@code a} and {@code b}. Divides before multiplying to
     * reduce (not eliminate) the risk of overflow. Throws {@link ArithmeticException} when both
     * arguments are zero, because {@code gcd(0, 0) == 0}. Not referenced by {@link #main}.
     */
    public static long lcm(long a, long b) {
        return a / gcd(a, b) * b;
    }

    /** Modulus applied to the printed answer. */
    public static final long MOD = 1000000007;

    /**
     * Reads {@code N} (int) and {@code M} (long) from standard input and prints
     * {@link #solve(int, long)} for them.
     */
    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            final int n = sc.nextInt();
            final long m = sc.nextLong();
            System.out.println(solve(n, m));
        }
    }

    /**
     * Computes the answer printed by {@link #main}.
     *
     * <p>Let {@code K = n / m} be the number of multiples of {@code m} in {@code [1, n]}. Every
     * multiple {@code k * m} (with {@code 1 <= k <= K}) starts with a count of {@code K}. The
     * multiple {@code m} itself ({@code k = 1}) loses 1. For every {@code d} in {@code [2, K]},
     * each multiple {@code k * m} with {@code d | k} loses 1. The result is
     * {@code sum(count[k] * (n - 2)!) mod MOD}, where {@code (n - 2)!} is the empty product 1
     * for {@code n < 3}.
     *
     * @param n upper bound of the range, must be {@code >= 0}
     * @param m step whose multiples are considered, must be {@code >= 1}
     * @return the sum described above, reduced modulo {@link #MOD}; 0 when {@code m > n}
     * @throws IllegalArgumentException if {@code n < 0} or {@code m < 1}
     */
    public static long solve(int n, long m) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative: " + n);
        }
        if (m < 1) {
            throw new IllegalArgumentException("m must be positive: " + m);
        }
        // No multiple of m lies in [1, n]. Returning early also keeps 2 * m from overflowing
        // when m is close to Long.MAX_VALUE.
        if (m > n) {
            return 0;
        }

        // (n - 2)! mod MOD; the loop body never runs for n < 3, leaving the empty product 1.
        long otherFact = 1;
        for (int i = 1; i <= n - 2; i++) {
            otherFact = otherFact * i % MOD;
        }

        // count[k] belongs to the multiple k * m. Only multiples of m are ever read, so indexing by
        // the quotient k needs K + 1 slots instead of n + 1. Since m <= n here, K fits in an int.
        final int multiples = (int) (n / m);
        long[] count = new long[multiples + 1];
        Arrays.fill(count, multiples);
        count[1]--;
        // (d * m) divides (k * m) exactly when d divides k, so this sieve over quotients matches
        // stepping through multiples of d * m in the original index space.
        for (int d = 2; d <= multiples; d++) {
            for (long k = d; k <= multiples; k += d) {
                count[(int) k]--;
            }
        }

        long answer = 0;
        for (int k = 1; k <= multiples; k++) {
            // count[k] <= K < 2^31 and otherFact < MOD < 2^30, so the product fits in a long.
            answer = (answer + count[k] * otherFact % MOD) % MOD;
        }
        return answer;
    }
}