結果
問題 | No.2362 Inversion Number of Mod of Linear |
ユーザー |
|
提出日時 | 2025-04-20 08:24:57 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 499 ms / 2,000 ms |
コード長 | 1,826 bytes |
コンパイル時間 | 578 ms |
コンパイル使用メモリ | 82,596 KB |
実行使用メモリ | 83,628 KB |
最終ジャッジ日時 | 2025-04-20 08:25:01 |
合計ジャッジ時間 | 3,537 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 8 |
ソースコード
def fsum(N, M, A, B): # sum_{i=0}^{N-1} floor((Ai+B) / M) # all non-negative add = A // M * N * (N - 1) // 2 + B // M * N A %= M B %= M if A == 0: return add max_val = (A * N + B) // M return add + max_val * N - fsum(max_val, A, M, M - B + A - 1) # i * floor((Ai+B) / M) # floor((Ai+B) / M) ^2 def psum(p, N): if p == 0: return N if p == 1: return N *(N-1)//2 if p == 2: return (N-1)*(N)*(2*N-1)//6 def fsum_general(N, M, A, B): # sum floor[(Ai+B)/M] # sum floor[(Ai+B)/M]*i # sum floor[(Ai+B)/M]^2 if A >= M: # print("TEST1") ans01, ans11, ans02 = fsum_general(N, M, A % M, B) q = A // M ans02 += q ** 2 * psum(2, N) ans02 += 2 * q * ans11 ans01 += q * psum(1, N) ans11 += q * psum(2, N) return ans01, ans11, ans02 if B >= M: # print("TEST2") ans01, ans11, ans02 = fsum_general(N, M, A, B % M) q = B // M ans02 += q ** 2 * N ans02 += 2 * q * ans01 ans01 += q * N ans11 += q * psum(1, N) return ans01, ans11, ans02 if A == 0: return 0, 0, 0 max_val = (A * N + B) // M ans01, ans11, ans02 = fsum_general(max_val, A, M, M - B + A - 1) return max_val * N - ans01, max_val * psum(1, N) - (ans02 - ans01) // 2, (2 * psum(1, max_val) + max_val) * N - 2 * ans11 - ans01 # ans01, ans11, ans02 = 0, 0, 0 # for i in range(N): # ans01 += (A*i + B) // M # ans11 += (A*i + B) // M * i # ans02 += ((A*i + B) // M) * ((A*i + B) // M) # return ans01, ans11, ans02 def solve(): N, M, X, Y = map(int, input().split()) ans = 0 # for i in range(N): # for j in range(i+1, N): # ans += (j*X+Y)//M-(i*X+Y)//M-(j-i)*X//M # print(ans) ans01, ans11, ans02 = fsum_general(N, M, X, Y) ans = ans11 - ((N-1)*ans01-ans11) ans01, ans11, ans02 = fsum_general(N-1, M, X, X) ans -= (N-1)*ans01 ans += ans11 print(ans) T = int(input()) for _ in range(T): solve()