from sys import stdin
from functools import lru_cache  # not used by the readers below; import retained

# Fast line reader. NOTE: this deliberately shadows the builtin input().
input = stdin.readline

# Modulus applied to the printed answer.
MOD = 1000000007


def next():
    """Read one line from stdin and return it without trailing whitespace.

    Not memoized: every call must consume a fresh line. (The name shadows
    the builtin ``next`` and is kept for compatibility.)
    """
    return input().rstrip()


def nextint():
    """Read one line and return it as an int."""
    return int(input())


def nextfloat():
    """Read one line and return it as a float."""
    return float(input())


def nextlist():
    """Read one line and return its whitespace-separated tokens."""
    return input().split()


def nextintlist():
    """Read one line and return its whitespace-separated tokens as ints."""
    return [int(token) for token in input().split()]


def nextfloatlist():
    """Read one line and return its whitespace-separated tokens as floats."""
    return [float(token) for token in input().split()]


def nextlist2():
    """Read a count n, then n lines; return the lines as stripped strings."""
    n = int(input())  # n is the number of lines that follow
    return [input().rstrip() for _ in range(n)]


def nextintlist2():
    """Read a count n, then n lines; return each line as an int."""
    n = int(input())  # n is the number of lines that follow
    return [int(input()) for _ in range(n)]


def nextfloatlist2():
    """Read a count n, then n lines; return each line as a float."""
    n = int(input())  # n is the number of lines that follow
    return [float(input()) for _ in range(n)]


def nextdoublelist():
    """Read a count n, then n lines; return each line as a list of tokens."""
    n = int(input())  # n is the number of lines that follow
    return [input().split() for _ in range(n)]


def nextdoubleintlist():
    """Read a count n, then n lines; return each line as a list of ints."""
    n = int(input())  # n is the number of lines that follow
    return [[int(token) for token in input().split()] for _ in range(n)]


def nextdoublefloatlist():
    """Read a count n, then n lines; return each line as a list of floats."""
    n = int(input())  # n is the number of lines that follow
    return [[float(token) for token in input().split()] for _ in range(n)]


def pow_k(x, n, mod=None):
    """Return ``x`` raised to the non-negative power ``n`` by repeated squaring.

    Uses O(log n) multiplications.

    Args:
        x: Base.
        n: Exponent; must be >= 0.
        mod: Optional modulus. When given, ``x`` and ``n`` must be ints and
            the result is ``x ** n % mod``, computed without ever building
            the full power.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("pow_k() does not support negative exponents")
    if mod is not None:
        # Built-in three-argument pow keeps every intermediate below mod.
        return pow(x, n, mod)
    if n == 0:
        return 1

    acc = 1
    while n > 1:
        if n % 2 != 0:
            acc *= x
        x *= x
        n //= 2
    return acc * x


def solve(n, k):
    """Return ``count * (n!) ** (n!)`` modulo ``MOD``.

    ``count`` is the number of ``i`` in ``range(n)`` with ``i % k == 0``.

    Raises:
        ZeroDivisionError: If ``k == 0`` and ``n > 0``.
    """
    factorial = 1
    count = 0
    for i in range(n):
        factorial *= i + 1
        if i % k == 0:
            count += 1
    # The exponent stays the exact factorial; only the power is reduced, so
    # the result equals reducing the exact product after the fact.
    return count * pow_k(factorial, factorial, MOD) % MOD


def main():
    """Read ``n k`` from one input line and print the answer."""
    s = nextintlist()
    print(solve(s[0], s[1]))


if __name__ == "__main__":
    main()