"""Print the rank of the inverse of the N-th permutation.

Reads two integers ``N S`` from one line of standard input, builds the
permutation of ``0..S-1`` whose zero-based lexicographic rank is ``N``,
inverts it, and prints the zero-based lexicographic rank of the inverse.
"""
import sys
from collections import *  # noqa: F401,F403 -- retained from the original script

input = sys.stdin.readline  # retained alias; nothing is read at import time


def _factorials(size):
    """Return ``[0!, 1!, ..., (size - 1)!]``; ``[1]`` when ``size`` is 0 or 1."""
    table = [1]
    for k in range(1, size):
        table.append(table[-1] * k)
    return table


def kth_permutation(rank, size):
    """Return the permutation of ``range(size)`` with the given lexicographic rank.

    Args:
        rank: Zero-based rank, ``0 <= rank < size!``.
        size: Number of elements being permuted.

    Returns:
        The permutation as a list of ``size`` distinct ints.

    Raises:
        ValueError: If ``rank`` is outside ``[0, size!)``.
    """
    fact = _factorials(size)
    total = fact[-1] * size if size else 1  # size!
    if not 0 <= rank < total:
        raise ValueError(
            f"rank {rank} is out of range for permutations of {size} elements"
        )

    remaining = list(range(size))
    perm = []
    for pos in range(size):
        # Each choice for this position heads a block of (size - pos - 1)!
        # consecutive permutations, so the quotient selects which unused
        # value goes here and the remainder is the rank inside that block.
        idx, rank = divmod(rank, fact[size - pos - 1])
        perm.append(remaining.pop(idx))
    return perm


def inverse_permutation(perm):
    """Return ``inv`` such that ``inv[perm[i]] == i`` for every index ``i``."""
    inv = [0] * len(perm)
    for pos, value in enumerate(perm):
        inv[value] = pos
    return inv


def permutation_rank(perm):
    """Return the zero-based lexicographic rank of ``perm``.

    ``perm`` must be a permutation of ``range(len(perm))``.
    """
    size = len(perm)
    fact = _factorials(size)
    used = set()
    rank = 0
    for pos, value in enumerate(perm):
        # Every still-unused value smaller than ``value`` would start a whole
        # block of (size - pos - 1)! permutations that sort before this one.
        smaller_unused = sum(1 for v in range(value) if v not in used)
        rank += smaller_unused * fact[size - pos - 1]
        used.add(value)
    return rank


def solve(n, s):
    """Return the rank of the inverse of the ``n``-th permutation of ``s`` items."""
    return permutation_rank(inverse_permutation(kth_permutation(n, s)))


def main():
    """Read ``N S`` from standard input and print the answer."""
    n, s = map(int, sys.stdin.readline().split())
    print(solve(n, s))


if __name__ == "__main__":
    main()