N, M = map(int, input().split()) print((pow(M, N, 1000000007) - M * pow(M - 1, N, 1000000007)) % 1000000007)