import sys
from collections import deque


def min_depth_sum(n, k, edges):
    """Return the sum of the k smallest node depths in a tree rooted at node 0.

    Args:
        n: Number of nodes, labelled 0..n-1.
        k: How many of the shallowest nodes to include.
        edges: Iterable of undirected (a, b) pairs of 0-based node labels.

    Returns:
        -1 if k > n; otherwise the sum of the k smallest depths (root depth 0).
        Only nodes reachable from node 0 are counted.
    """
    if n < k:
        return -1

    adj = [[] for _ in range(n)]
    for a, b in edges:
        adj[a].append(b)
        adj[b].append(a)

    # Iterative BFS instead of recursion: a path-shaped tree would otherwise
    # overflow the C stack even with a raised recursion limit.
    depth = [-1] * n
    queue = deque()
    if n:
        depth[0] = 0
        queue.append(0)

    total = 0
    taken = 0
    # BFS dequeues nodes in non-decreasing depth order, so the first k nodes
    # popped are exactly the k shallowest -- no sort required.
    while queue and taken < k:
        node = queue.popleft()
        total += depth[node]
        taken += 1
        child_depth = depth[node] + 1
        for nxt in adj[node]:
            if depth[nxt] == -1:
                depth[nxt] = child_depth
                queue.append(nxt)
    return total


def main():
    """Read 'n k' followed by n-1 one-based edges from stdin and print the answer."""
    data = sys.stdin.buffer.read().split()
    n, k = int(data[0]), int(data[1])
    if n < k:
        print(-1)
        return
    nums = list(map(int, data[2:2 + 2 * (n - 1)]))
    # Convert the flat list of one-based endpoints into zero-based pairs.
    edges = [(nums[i] - 1, nums[i + 1] - 1) for i in range(0, len(nums) - 1, 2)]
    print(min_depth_sum(n, k, edges))


if __name__ == '__main__':
    main()