import java.util.Arrays;
import java.util.LinkedList;
import java.util.Scanner;

/**
 * Reads two integers N and M and prints k * (k - 1) * (N - 2)! modulo {@link #MOD},
 * where k is the number of multiples of M in [M, N].
 *
 * <p>This equals summing {@code 2 * (N - 2)!} over every unordered pair of distinct
 * multiples of M in [M, N], which is how the value was originally computed.
 */
public class Main {

    /** Modulus applied to every intermediate product and to the printed answer (prime). */
    public static final long MOD = 1000000007;

    /**
     * Returns the greatest common divisor of {@code a} and {@code b} using Euclid's algorithm.
     *
     * @param a first operand
     * @param b second operand
     * @return gcd(a, b); returns {@code a} when {@code b == 0}
     */
    public static long gcd(long a, long b) {
        return b == 0 ? a : gcd(b, a % b);
    }

    /**
     * Returns the least common multiple of {@code a} and {@code b}.
     * Divides before multiplying to reduce (but not eliminate) the risk of overflow.
     *
     * @param a first operand
     * @param b second operand
     * @return lcm(a, b)
     * @throws ArithmeticException if both arguments are 0 (division by zero)
     */
    public static long lcm(long a, long b) {
        return a / gcd(a, b) * b;
    }

    /**
     * Reads N and M (whitespace-separated longs) from standard input and prints the
     * answer modulo {@link #MOD} on a single line.
     *
     * @param args ignored
     * @throws IllegalArgumentException if M is not positive
     */
    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            final long n = sc.nextLong();
            final long m = sc.nextLong();
            System.out.println(solve(n, m));
        }
    }

    /** Computes k * (k - 1) * (n - 2)! mod MOD, where k = number of multiples of m in [m, n]. */
    private static long solve(long n, long m) {
        if (m <= 0) {
            // A non-positive step would never advance through [m, n]; reject it explicitly.
            throw new IllegalArgumentException("M must be positive, got " + m);
        }
        // Java's '/' truncates toward zero, so a negative n would give a negative count;
        // clamp to 0 because no multiples of a positive m lie in [m, n] then.
        final long k = Math.max(0L, n / m);
        if (k < 2) {
            return 0; // fewer than two multiples: no pair exists
        }
        // Summing 2 * (n-2)! over the k*(k-1)/2 unordered pairs gives k*(k-1) * (n-2)!.
        // Each factor is reduced below MOD first so the product fits in a long.
        final long orderedPairs = (k % MOD) * ((k - 1) % MOD) % MOD;
        return orderedPairs * factorialMod(n - 2) % MOD;
    }

    /** Returns x! mod MOD for x >= 0 (1 for x <= 1). */
    private static long factorialMod(long x) {
        if (x >= MOD) {
            return 0; // MOD is prime and appears as a factor of x!, so the product is 0
        }
        long result = 1;
        // long index: an int index would overflow for large x and never terminate
        for (long i = 2; i <= x; i++) {
            result = result * i % MOD;
        }
        return result;
    }
}