def weighted_division_steps(n: int, m: int) -> int:
    """Return the sum 1 + 2 + 4 + ... with one term per floor-division step.

    Starting from ``n``, the value is repeatedly replaced by ``n // m`` while
    it is still positive.  The first step contributes 1 and every later step
    contributes twice the previous one, so after ``k`` steps the result is
    ``2**k - 1``.  For ``m >= 2`` and ``n > 0``, ``k`` equals the number of
    base-``m`` digits of ``n``.

    Args:
        n: Starting value; a non-positive value yields 0.
        m: Divisor applied at every step.

    Returns:
        The accumulated sum of doubling weights.

    Raises:
        ValueError: If ``n > 0`` and ``m == 1``; dividing by 1 never shrinks
            ``n``, so the loop would otherwise run forever.
        ZeroDivisionError: If ``n > 0`` and ``m == 0``.
    """
    if n > 0 and m == 1:
        # n // 1 == n, so the loop below would never terminate.
        raise ValueError("m must not be 1 when n is positive")

    total = 0
    weight = 1
    while n > 0:
        total += weight
        n //= m
        weight *= 2  # each further division step counts twice as much
    return total


def main() -> None:
    """Read ``n`` and ``m`` from one line of stdin and print the result."""
    n, m = map(int, input().split())
    print(weighted_division_steps(n, m))


if __name__ == "__main__":
    main()