import sys


def solve(n: int, k: int) -> int:
    """Return 1 plus the sum of the ``min(k, n)`` largest integers in ``1..n``.

    With ``x = min(k, n)`` this is ``1 + n + (n - 1) + ... + (n - x + 1)``,
    computed in O(1) via the arithmetic-series formula.

    Args:
        n: Upper end of the range ``1..n``.
        k: Number of largest terms to take (capped at ``n``).

    Returns:
        The series sum plus one. Returns 1 when ``x == 0``.
    """
    x = min(k, n)
    # (first + last) * count is always even for an arithmetic series of
    # consecutive integers, so the floor division is exact.
    # NOTE(review): a negative k is not clamped and yields a value below 1;
    # presumably the input constraints guarantee k >= 0 — confirm.
    return (n + (n - x + 1)) * x // 2 + 1


def main() -> None:
    """Read ``N K`` from the first line of stdin and print ``solve(N, K)``."""
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()