def solve(n: int, k: int) -> int:
    """Return the script's answer for the integer pair ``(n, k)``.

    The answer is ``n - 1`` whenever ``k`` differs from 1. It is
    ``(n - 2) * (n - 1)`` when ``k == 1``.

    NOTE(review): the problem statement this formula answers is not visible
    here. The meaning of ``n``/``k`` and their valid ranges are unconfirmed.
    """
    if k != 1:
        return n - 1
    return (n - 2) * (n - 1)


if __name__ == "__main__":
    # Input format: a single line holding two whitespace-separated integers.
    n, k = map(int, input().split())
    print(solve(n, k))