def solve(x: int, k: int) -> int:
    """Return the original 1-based position of the last surviving element.

    Start from the list 1..x. Each round deletes the elements at positions
    1, k+1, 2k+1, ... of the current list (ceil(len / k) elements). This
    repeats until one element remains, and that element's original position
    is returned. Values x <= k are returned unchanged.

    A direct recursive formulation needs one stack frame per round. Once
    x <= 2k, each round removes only two elements, so there can be about
    k/2 rounds or more. That overflows Python's recursion limit for
    moderately large k. This version instead counts the rounds in a forward
    loop, then maps the survivor's position back one round at a time,
    using O(1) extra memory.

    Raises:
        ValueError: if k < 2 and x > k. The rounds would never shrink the
            list to a single survivor in a well-defined way.
    """
    if x <= k:
        return x
    if k < 2:
        raise ValueError("k must be at least 2 when x > k")

    # Forward pass: each round removes ceil(x / k) elements; stop once at
    # most k elements remain.
    rounds = 0
    while x > k:
        x -= (x + k - 1) // k
        rounds += 1

    # With x <= k, every remaining round removes only the first element,
    # so the survivor is the last one, at position x.
    pos = x
    # Backward pass: removed elements sit at positions 1, k+1, 2k+1, ...,
    # each opening a group of k - 1 survivors. The pos-th survivor is
    # therefore preceded by (pos - 1) // (k - 1) + 1 removed elements,
    # which gives its position before that round.
    for _ in range(rounds):
        pos += (pos - 1) // (k - 1) + 1
    return pos


def main() -> None:
    """Read "n k" from standard input and print solve(n, k)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()