import java.util.*;

/**
 * Console program: reads two integers {@code n} and {@code k} from standard input and prints
 * {@code 1 + C(n-k, 1) + ... + C(n-k, m)}, where {@code m = min(n - k, n - 3)} when {@code k == 1}
 * and {@code m = min(n - k, n - 2)} otherwise (an empty sum leaves just {@code 1}). The input
 * {@code n == 2, k == 1} is special-cased to print {@code 0}.
 */
public class Main {
  public static void main(String[] args) {
    Scanner input = new Scanner(System.in);
    int n = input.nextInt();
    int k = input.nextInt();
    System.out.println(answer(n, k));
  }

  /** Computes the value that {@link #main} prints for the given {@code n} and {@code k}. */
  private static long answer(int n, int k) {
    // NOTE(review): without this branch the summation below would yield 1 for (2, 1); the
    // problem statement presumably requires 0 here — confirm against it.
    if (n == 2 && k == 1) {
      return 0;
    }
    long[][] binom = pascalTriangle(n);
    // Largest index allowed by the n-based limit; the k == 1 case stops one term earlier.
    int bound = (k == 1) ? n - 3 : n - 2;
    long sum = 1; // starting value, playing the role of the C(n-k, 0) term
    for (int i = 1; i <= bound && i <= n - k; i++) {
      sum += binom[n - k][i];
    }
    return sum;
  }

  /**
   * Builds rows {@code 0..size} of Pascal's triangle: {@code table[r][c]} holds C(r, c) for
   * {@code c <= r} and 0 above the diagonal.
   *
   * <p>NOTE(review): {@code long} entries overflow for rows above 66; presumably the input bounds
   * keep {@code n} small — confirm.
   */
  private static long[][] pascalTriangle(int size) {
    long[][] table = new long[size + 1][size + 1];
    for (int r = 0; r <= size; r++) {
      table[r][0] = 1;
      table[r][r] = 1;
      for (int c = 1; c < r; c++) {
        table[r][c] = table[r - 1][c - 1] + table[r - 1][c];
      }
    }
    return table;
  }
}