def min_cost(x, a):
    """Return the minimum cost of splitting ``a`` into ``i`` integer parts.

    Each part of size ``s`` costs ``s**2 + x``. For every part count ``i`` in
    ``1..a`` the parts are made as equal as possible: ``a % i`` parts of size
    ``a // i + 1`` and the rest of size ``a // i``. The cheapest total over
    all ``i`` is returned.

    Args:
        x: Fixed cost added per part.
        a: Quantity to split; must be at least 1.

    Returns:
        The minimum total cost as an ``int``.

    Raises:
        ValueError: If ``a < 1``, since there is no part count to try.
    """
    if a < 1:
        raise ValueError(f"a must be a positive integer, got {a}")
    # Seed with the single-part cost. This is a genuine candidate, unlike a
    # fixed "infinity" sentinel that very large inputs can exceed.
    best = a * a + x
    for parts in range(2, a + 1):
        q, r = divmod(a, parts)
        # r parts of size q + 1, (parts - r) parts of size q, plus x per part.
        cost = r * (q + 1) ** 2 + (parts - r) * q * q + parts * x
        if cost < best:
            best = cost
    return best


def solve():
    """Read one ``x a`` test case from stdin and print its minimum cost."""
    x, a = map(int, input().split())
    print(min_cost(x, a))


if __name__ == "__main__":
    # First line: number of test cases; each following line is one "x a" case.
    t = int(input())
    for _ in range(t):
        solve()