"""Read ``n p`` from stdin and print ``v * F**F mod (10**9 + 7)``.

Here ``v`` is the exponent of ``p`` in ``n!`` and ``F`` is ``n!`` reduced
modulo ``(10**9 + 7)**2``.
"""

MOD = 10**9 + 7


def _factorial_valuation(n, p):
    """Return the exponent of ``p`` in ``n!`` (Legendre's formula).

    Sums ``n // p + n // p**2 + ...``. This equals the count obtained by
    repeatedly dividing each of ``1..n`` by ``p``, for any integer ``p >= 2``.
    Returns 0 for ``n <= 0``.
    """
    count = 0
    q = n // p
    while q > 0:
        count += q
        q //= p
    return count


def solve(n, p, mod=MOD):
    """Return ``v * pow(F, F, mod**2) % mod``.

    ``v`` is the exponent of ``p`` in ``n!``, and ``F = n! % mod**2``.

    Args:
        n: upper bound of the factorial; ``n <= 0`` yields 0.
        p: base whose multiplicity in ``n!`` is counted; must be ``>= 2``.
        mod: output modulus (default ``10**9 + 7``).

    Raises:
        ValueError: if ``p < 2``. For ``p == 1`` the original division loop
            never terminated; for ``p == 0`` it divided by zero.
    """
    if p < 2:
        raise ValueError(f"p must be >= 2, got {p}")

    num = _factorial_valuation(n, p)

    square = mod * mod
    facts = 1
    for i in range(1, n + 1):
        facts = facts * i % square  # keep n! bounded, reduced mod mod**2

    # NOTE(review): the exponent here is n! already reduced mod mod**2, so in
    # general this is NOT (n!)**(n!) mod mod**2. This is preserved exactly
    # from the original script; confirm the intended formula against the
    # problem statement.
    po = pow(facts, facts, square)
    return num * po % mod


def main():
    """Parse ``n p`` from one stdin line and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()