#yuki978
"""Sum of pairwise products of a linear-recurrence sequence, modulo 1e9+7.

Reads two integers ``n`` and ``p`` from one line of standard input and
prints the result of :func:`solve`.
"""
import sys

MOD = 10**9 + 7


def solve(n, p):
    """Return the sum of a_s * a_t over all 1 <= s <= t <= n, modulo MOD.

    The sequence is a_1 = 0, a_2 = 1, a_k = p * a_(k-1) + a_(k-2).

    Args:
        n: Number of sequence terms to include.
        p: Multiplier used in the recurrence.

    Returns:
        The sum reduced modulo ``MOD``. For ``n <= 2`` the value ``n - 1``
        is returned directly (0 for n == 1, 1 for n == 2), exactly as the
        original script printed it.
    """
    if n <= 2:
        return n - 1

    prev, cur = 0, 1  # a_1, a_2
    prefix = 1        # a_1 + ... + a_k (mod MOD) for the current k
    total = 1         # answer for the first k terms; 1 when k == 2

    for _ in range(2, n):
        prev, cur = cur, (cur * p + prev) % MOD
        prefix = (prefix + cur) % MOD
        # The new term pairs with every term up to and including itself,
        # so it contributes cur * (a_1 + ... + a_k) to the total.
        total = (total + cur * prefix) % MOD

    return total


def main():
    """Read ``n p`` from standard input and print the answer."""
    # sys.stdin.readline works on both Python 2 and Python 3, unlike
    # raw_input, which only exists on Python 2.
    n, p = map(int, sys.stdin.readline().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()