"""Sum of pairwise products of a linear-recurrence sequence, modulo 10**9 + 7.

Reads ``N p`` from the first line of stdin. It builds the sequence
A_0 = A_1 = 0, A_2 = 1, A_n = p * A_{n-1} + A_{n-2}. It then prints
sum_{0 <= i <= j <= N} A_i * A_j (mod 10**9 + 7).
"""
import sys

MOD = 10 ** 9 + 7


def solve(n, p, mod=MOD):
    """Return sum_{0 <= i <= j <= n} A_i * A_j modulo *mod*.

    The sequence is A_0 = A_1 = 0, A_2 = 1, A_k = p * A_{k-1} + A_{k-2},
    with every term reduced modulo *mod*.

    Args:
        n: Largest index of the sequence to include. Values <= 1 yield 0,
            because every included term is zero.
        p: Multiplier in the recurrence.
        mod: Modulus. It must be odd so that 2 is invertible.

    Returns:
        The pairwise-product sum reduced into ``range(mod)``.
    """
    if n <= 1:
        # A_0 and A_1 are both zero, so every product vanishes.
        return 0
    s1 = 0  # sum of A_k
    s2 = 0  # sum of A_k ** 2
    prev, cur = 0, 1  # A_1, A_2
    for _ in range(2, n + 1):
        s1 += cur
        s2 += cur * cur
        prev, cur = cur, (cur * p + prev) % mod
    s1 %= mod
    s2 %= mod
    # The sum over i <= j equals (S1^2 + S2) / 2. For odd mod, the inverse
    # of 2 is (mod + 1) // 2.
    inv2 = (mod + 1) // 2
    return (s1 * s1 + s2) * inv2 % mod


def main():
    """Read ``N p`` from the first stdin line and print the answer."""
    n, p = map(int, sys.stdin.buffer.readline().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()