N,K=map(int,input().split()) mod=10**9+7 X=K*(K+1)//2+K print((pow(X,N,mod)-pow(K*(K+1)//2,N,mod))%mod)