import java.util.Scanner;

/**
 * Reads {@code n} and {@code m} from standard input and prints, modulo {@link #MOD}, the number
 * of surjective functions from an {@code n}-element set onto an {@code m}-element set.
 */
public class Main {
    /** Prime modulus applied to every intermediate and final result. */
    static final long MOD = 1_000_000_007;

    public static void main(String[] args) {
        solver();
    }

    /** Reads {@code n} (a long) then {@code m} (an int) from stdin and prints the answer. */
    static void solver() {
        // try-with-resources so the Scanner is always closed.
        try (Scanner sc = new Scanner(System.in)) {
            long n = sc.nextLong();
            int m = sc.nextInt();
            System.out.println(countSurjections(n, m));
        }
    }

    /**
     * Counts surjections from an {@code n}-set onto an {@code m}-set using inclusion-exclusion:
     * {@code sum_{i=0..m} (-1)^i * C(m, i) * (m - i)^n  (mod MOD)}.
     *
     * @param n size of the domain; must be non-negative
     * @param m size of the codomain; must be non-negative and below {@link #MOD}
     * @return the number of surjections, reduced into {@code [0, MOD)}
     * @throws IllegalArgumentException if {@code n} or {@code m} is negative, or {@code m >= MOD}
     */
    static long countSurjections(long n, int m) {
        if (n < 0 || m < 0) {
            throw new IllegalArgumentException(
                    "n and m must be non-negative: n=" + n + ", m=" + m);
        }
        long[] binom = nth_Pascal_triangle(m);
        long total = 0;
        // i == m is included on purpose: its term (-1)^m * 0^n vanishes for n >= 1,
        // but equals (-1)^m when n == 0 and is required for the correct result there.
        for (int i = 0; i <= m; i++) {
            long term = binom[i] * pow(m - i, n) % MOD;
            total = (i % 2 == 0) ? total + term : total - term;
            total = Math.floorMod(total, MOD);
        }
        return total;
    }

    /**
     * Returns row {@code n} of Pascal's triangle modulo {@link #MOD}, i.e.
     * {@code C(n, 0), C(n, 1), ..., C(n, n)}.
     *
     * <p>Uses {@code C(n, k) = C(n, k-1) * (n - k + 1) / k}, dividing by {@code k} via its
     * Fermat inverse. This runs in O(n log MOD) instead of building every earlier row.
     *
     * @param n row index; {@code -1} yields an empty array
     * @return array of length {@code n + 1} holding the binomial coefficients mod {@code MOD}
     * @throws IllegalArgumentException if {@code n >= MOD} (the inverse of {@code k} would not exist)
     */
    static long[] nth_Pascal_triangle(int n) {
        if (n >= MOD) {
            throw new IllegalArgumentException("row index must be below MOD: n=" + n);
        }
        long[] row = new long[n + 1];
        if (n < 0) {
            return row;
        }
        row[0] = 1;
        for (int k = 1; k <= n; k++) {
            long numerator = (n - k + 1) % MOD;
            long inverseK = pow(k, MOD - 2); // MOD is prime and 0 < k < MOD
            row[k] = row[k - 1] * numerator % MOD * inverseK % MOD;
        }
        return row;
    }

    /**
     * Computes {@code a^n mod MOD} by binary exponentiation.
     *
     * @param a base; any long, including negative values (normalised into {@code [0, MOD)} first)
     * @param n exponent; values {@code <= 0} yield {@code 1}
     * @return {@code a^n} reduced into {@code [0, MOD)}
     */
    static long pow(long a, long n) {
        // Reducing first keeps base * base below 2^63 and makes negative bases non-negative.
        long base = Math.floorMod(a, MOD);
        long result = 1;
        for (long e = n; e > 0; e >>= 1) {
            if ((e & 1) == 1) {
                result = result * base % MOD;
            }
            base = base * base % MOD;
        }
        return result;
    }
}