import java.util.Arrays;
import java.util.Scanner;

/**
 * Reads two integers {@code h} and {@code w} from standard input and prints, modulo
 * 1_000_000_007,
 *
 * <pre>
 *   h*(w-1) + w*(h-1) + sum_{dx=1}^{w-1} 2*(w-dx) * sum_{1<=dy<=h-1, gcd(dx,dy)=1} (h-dy)
 * </pre>
 *
 * which counts the unordered pairs of points in an {@code h}-row by {@code w}-column lattice grid
 * whose connecting segment contains no other lattice point.
 */
public class Main {
  private static final int MOD = 1000000007;

  /**
   * Upper bound on the number of distinct prime factors of a positive {@code int}: the product of
   * the first nine primes (223092870) fits in an int, the product of the first ten does not.
   */
  private static final int MAX_DISTINCT_PRIMES = 9;

  /** Reads {@code h w} from standard input and prints the pair count described on the class. */
  public static void main(String[] args) {
    Scanner scanner = new Scanner(System.in);
    int h = scanner.nextInt();
    int w = scanner.nextInt();
    System.out.println(countVisiblePairs(h, w));
  }

  /**
   * Computes the pair count described on the class, reduced modulo {@link #MOD}.
   *
   * @throws IllegalArgumentException if either dimension is negative
   */
  private static long countVisiblePairs(int h, int w) {
    if (h < 0 || w < 0) {
      throw new IllegalArgumentException("negative grid size: h=" + h + ", w=" + w);
    }
    if (h == 0 || w == 0) {
      return 0; // an empty grid has no pairs
    }
    // Axis-aligned pairs: only neighbours at distance 1 qualify. Each product is reduced before
    // the addition so the sum cannot overflow a long for large dimensions.
    long result = ((long) h * (w - 1) % MOD + (long) w * (h - 1) % MOD) % MOD;

    int[] largestPrime = largestPrimeFactors(w);
    for (int dx = 1; dx < w; dx++) {
      int[] primes = distinctPrimeFactors(dx, largestPrime);
      long rowWeight = sumCoprimeWeights(h, primes) % MOD;
      // (w - dx) horizontal placements, times 2 for the two diagonal orientations.
      // 2L forces long arithmetic: 2 * (w - dx) in int overflows once w - dx > 2^30.
      result = (result + 2L * (w - dx) * rowWeight) % MOD;
    }
    return result;
  }

  /**
   * Builds a table over {@code [0, limit)} whose entry {@code n} is the largest prime factor of
   * {@code n} when {@code n} is composite, and 0 for 0, 1 and primes.
   */
  private static int[] largestPrimeFactors(int limit) {
    int[] largest = new int[limit];
    for (int p = 2; p < limit; p++) {
      if (largest[p] != 0) {
        continue; // composite, already marked by a smaller prime
      }
      // Primes are visited in increasing order, so the last prime written to a multiple is its
      // largest prime factor. Bounding k by (limit - 1) / p keeps p * k from overflowing.
      for (int k = 2; k <= (limit - 1) / p; k++) {
        largest[p * k] = p;
      }
    }
    return largest;
  }

  /**
   * Returns the distinct prime factors of {@code n}, largest first.
   *
   * @param n a value in {@code [1, largestPrime.length)}
   * @param largestPrime table produced by {@link #largestPrimeFactors(int)}
   */
  private static int[] distinctPrimeFactors(int n, int[] largestPrime) {
    int[] primes = new int[MAX_DISTINCT_PRIMES];
    int count = 0;
    int rest = n;
    // Repeatedly strip the largest prime factor. Primes come out in non-increasing order, so
    // repeated factors are adjacent and comparing with the last stored prime removes them.
    while (largestPrime[rest] != 0) {
      int p = largestPrime[rest];
      if (count == 0 || primes[count - 1] != p) {
        primes[count++] = p;
      }
      rest /= p;
    }
    // rest is now 1 or a prime no larger than any prime already collected.
    if (rest > 1 && (count == 0 || primes[count - 1] != rest)) {
      primes[count++] = rest;
    }
    return Arrays.copyOf(primes, count);
  }

  /**
   * Returns the sum of {@code h - dy} over every {@code dy} in {@code [1, h - 1]} divisible by
   * none of {@code primes}, using inclusion-exclusion over the subsets of {@code primes}.
   *
   * @param h grid height, at least 1
   * @param primes distinct primes; their product must fit in an int
   */
  private static long sumCoprimeWeights(int h, int[] primes) {
    int subsets = 1 << primes.length;
    // product[mask] = product of the primes selected by mask. It divides the number the primes
    // came from, so it fits in an int.
    int[] product = new int[subsets];
    product[0] = 1;
    // Partial sums may wrap for very large h, but long arithmetic is exact modulo 2^64 and the
    // final count and sum are both in range, so the result is still correct.
    long count = 0; // number of admissible dy
    long sum = 0; // sum of admissible dy
    for (int mask = 0; mask < subsets; mask++) {
      if (mask != 0) {
        // Clearing the lowest set bit gives a smaller mask, which was already computed.
        product[mask] = product[mask & (mask - 1)] * primes[Integer.numberOfTrailingZeros(mask)];
      }
      int d = product[mask];
      long multiples = (h - 1) / d; // multiples of d in [1, h - 1]
      long multiplesSum = multiples * (multiples + 1) / 2 * d;
      if (Integer.bitCount(mask) % 2 == 0) {
        count += multiples;
        sum += multiplesSum;
      } else {
        count -= multiples;
        sum -= multiplesSum;
      }
    }
    return count * h - sum;
  }
}