"""Read integers n and p and print e * (n!)**(n!) mod 1000000007.

Here e is the largest k such that p**k divides n!. The file also carries a
generic set of stdin-reading helpers (a competitive-programming template);
only ``nextintlist`` is used by ``main``.
"""
from sys import stdin
from functools import lru_cache  # kept for compatibility; next() no longer uses it

input = stdin.readline  # fast line reader; shadows the builtin on purpose

MOD = 1000000007


def next():
    """Return the next input line with trailing whitespace stripped.

    Intentionally shadows the builtin ``next`` (template convention).
    Must not be memoized: it takes no arguments, so a cache would keep
    returning the first line instead of reading new ones.
    """
    return input().rstrip()


def nextint():
    """Read one line and return it as an int."""
    return int(input())


def nextfloat():
    """Read one line and return it as a float."""
    return float(input())


def nextlist():
    """Read one line and return its whitespace-separated tokens as strings."""
    return list(input().rstrip().split())


def nextintlist():
    """Read one line and return its whitespace-separated tokens as ints."""
    return list(map(int, input().rstrip().split()))


def nextfloatlist():
    """Read one line and return its whitespace-separated tokens as floats."""
    return list(map(float, input().rstrip().split()))


def nextlist2():
    """Read a count n, then n lines; return them stripped of trailing whitespace."""
    n = int(input())  # n is the number of lines that follow
    return [input().rstrip() for _ in range(n)]


def nextintlist2():
    """Read a count n, then n lines each holding one int."""
    n = int(input())  # n is the number of lines that follow
    return [int(input()) for _ in range(n)]


def nextfloatlist2():
    """Read a count n, then n lines each holding one float."""
    n = int(input())  # n is the number of lines that follow
    return [float(input()) for _ in range(n)]


def nextdoublelist():
    """Read a count n, then n lines; return each line's tokens as strings."""
    n = int(input())  # n is the number of lines that follow
    return [list(input().rstrip().split()) for _ in range(n)]


def nextdoubleintlist():
    """Read a count n, then n lines; return each line's tokens as ints."""
    n = int(input())  # n is the number of lines that follow
    return [list(map(int, input().rstrip().split())) for _ in range(n)]


def nextdoublefloatlist():
    """Read a count n, then n lines; return each line's tokens as floats."""
    n = int(input())  # n is the number of lines that follow
    return [list(map(float, input().rstrip().split())) for _ in range(n)]


def _multiplicity(value, p):
    """Return the largest k such that p**k divides value (by repeated division)."""
    count = 0
    while value % p == 0:  # raises ZeroDivisionError when p == 0
        value //= p
        count += 1
    return count


def solve(n, p):
    """Return e * (n!)**(n!) % MOD, where e is the exponent of p in n!.

    For n <= 0 the factorial is taken as 1 (empty product), matching the
    original loop. p may be composite, so the exponent is counted by
    repeated division of n! rather than Legendre's formula.

    Raises:
        ValueError: if p is 1 or -1 (the exponent would be unbounded).
        ZeroDivisionError: if p is 0.
    """
    if p in (1, -1):
        raise ValueError("p must not be 1 or -1: its multiplicity in n! is unbounded")
    fact = 1
    for i in range(2, n + 1):
        fact *= i
    exponent = _multiplicity(fact, p)
    # Three-argument pow reduces as it goes; fact ** fact would build an
    # integer with roughly fact * log10(fact) digits before the modulo.
    return exponent * pow(fact, fact, MOD) % MOD


def main():
    """Read 'n p' from the first input line and print the answer."""
    nums = nextintlist()
    print(solve(nums[0], nums[1]))


if __name__ == "__main__":
    main()