結果

問題 No.2362 Inversion Number of Mod of Linear
ユーザー Benjamin Qi
提出日時 2025-04-20 08:24:57
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 499 ms / 2,000 ms
コード長 1,826 bytes
コンパイル時間 578 ms
コンパイル使用メモリ 82,596 KB
実行使用メモリ 83,628 KB
最終ジャッジ日時 2025-04-20 08:25:01
合計ジャッジ時間 3,537 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 8
権限があれば一括ダウンロードができます

ソースコード

diff #

def fsum(N, M, A, B):
	# sum_{i=0}^{N-1} floor((Ai+B) / M)
	# all non-negative
	add = A // M * N * (N - 1) // 2 + B // M * N
	A %= M
	B %= M
	if A == 0:
		return add
	max_val = (A * N + B) // M
	return add + max_val * N - fsum(max_val, A, M, M - B + A - 1)


# i * floor((Ai+B) / M)
# floor((Ai+B) / M) ^2

def psum(p, N):
	if p == 0:
		return N
	if p == 1:
		return N *(N-1)//2
	if p == 2:
		return (N-1)*(N)*(2*N-1)//6

def fsum_general(N, M, A, B):
	# sum floor[(Ai+B)/M]
	# sum floor[(Ai+B)/M]*i
	# sum floor[(Ai+B)/M]^2

	if A >= M:
		# print("TEST1")
		ans01, ans11, ans02 = fsum_general(N, M, A % M, B)
		q = A // M
		ans02 += q ** 2 * psum(2, N)
		ans02 += 2 * q * ans11
		ans01 += q * psum(1, N)
		ans11 += q * psum(2, N)
		return ans01, ans11, ans02

	if B >= M:
		# print("TEST2")
		ans01, ans11, ans02 = fsum_general(N, M, A, B % M)
		q = B // M
		ans02 += q ** 2 * N
		ans02 += 2 * q * ans01
		ans01 += q * N
		ans11 += q * psum(1, N)
		return ans01, ans11, ans02

	if A == 0:
		return 0, 0, 0

	max_val = (A * N + B) // M
	ans01, ans11, ans02 = fsum_general(max_val, A, M, M - B + A - 1)
	return max_val * N - ans01, max_val * psum(1, N) - (ans02 - ans01) // 2, (2 * psum(1, max_val) + max_val) * N - 2 * ans11 - ans01

	# ans01, ans11, ans02 = 0, 0, 0
	# for i in range(N):
	# 	ans01 += (A*i + B) // M
	# 	ans11 += (A*i + B) // M * i
	# 	ans02 += ((A*i + B) // M) * ((A*i + B) // M)
	# return ans01, ans11, ans02

def solve():
	N, M, X, Y = map(int, input().split())
	ans = 0
	# for i in range(N): 
	# 	for j in range(i+1, N):
	# 		ans += (j*X+Y)//M-(i*X+Y)//M-(j-i)*X//M
	# print(ans)
	ans01, ans11, ans02 = fsum_general(N, M, X, Y)
	ans = ans11 - ((N-1)*ans01-ans11)
	ans01, ans11, ans02 = fsum_general(N-1, M, X, X)
	ans -= (N-1)*ans01
	ans += ans11
	print(ans)


T = int(input())
for _ in range(T):
	solve()
0