"""Last person left when every round removes positions 0, k, 2k, ... of the line.

Reads ``n k`` from standard input and prints the 1-based original position of
the person who is removed last (equivalently, the last one remaining).
"""
import sys


def solve(x, k):
    """Return the 0-based original position of the last survivor among *x* people.

    Each round removes the people at positions 0, k, 2k, ... of the current
    line, i.e. ceil(x / k) of them.  Once x <= k only position 0 is removed
    per round, so the last person (index x - 1) outlives everyone else.

    The previous recursive form used one stack frame per round (roughly
    k * ln(x / k) frames), which can overflow the C stack for large k even with
    a raised recursion limit.  Because the index remap applied per round does
    not depend on the line length, it is enough to count the rounds and then
    replay the remap iteratively.

    Raises ZeroDivisionError when k == 1 and x > 1, as the original did.
    """
    rounds = 0
    while x > k:
        x -= (x + k - 1) // k  # ceil(x / k) people leave this round
        rounds += 1
    pos = x - 1
    # A survivor at index j after a round sat at j // (k-1) * k + 1 + j % (k-1)
    # before it: every block of k positions lost its first member.
    for _ in range(rounds):
        block, offset = divmod(pos, k - 1)
        pos = block * k + 1 + offset
    return pos


def main():
    """Read ``n k`` from stdin and print the 1-based position of the survivor."""
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k) + 1)


if __name__ == "__main__":
    main()