def solve(n: int, m: int) -> int:
    """Return 1 + 2 + 4 + ... + 2**k for repeated floor division of n by m.

    k is the number of times *n* is replaced by ``n // m`` while ``n > 1``.
    The result is therefore ``2**(k + 1) - 1``; when ``n <= 1`` no division
    happens and the result is 1.

    Args:
        n: Starting value.
        m: Divisor applied at each step.

    Returns:
        The accumulated sum described above.

    Raises:
        ValueError: If ``m == 1`` and ``n > 1``. ``n // 1 == n``, so the
            loop would otherwise never terminate.
        ZeroDivisionError: If ``m == 0`` and ``n > 1``.
    """
    if m == 1 and n > 1:
        # Without this guard the loop below spins forever.
        raise ValueError("m must not be 1 when n > 1: n // 1 never shrinks")
    term = 1   # power of two added on the current step
    total = 1  # running sum; starts with the base term for n <= 1
    while n > 1:
        n //= m
        term *= 2
        total += term
    return total


def main() -> None:
    """Read ``n m`` from a single line of stdin and print ``solve(n, m)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()