n, k = map(int,input().split()) t = n * (n + 1) // 2 * k print(t // (n + 1) // k)