結果
| 問題 | No.1300 Sum of Inversions |
| コンテスト | |
| ユーザー | 回転 |
| 提出日時 | 2026-05-22 18:11:37 |
| 言語 | PyPy3 (7.3.17) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 6,881 bytes |
| 記録 | |
| コンパイル時間 | 145 ms |
| コンパイル使用メモリ | 85,120 KB |
| 実行使用メモリ | 252,932 KB |
| 最終ジャッジ日時 | 2026-05-22 18:11:55 |
| 合計ジャッジ時間 | 4,357 ms |
| ジャッジサーバーID (参考情報) | judge2_1 / judge3_0 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | TLE * 1 -- * 33 |
ソースコード
class WaveletMatrix:
    """Wavelet matrix over a sequence of non-negative integers.

    Level ``i`` (``0 <= i < lg``) looks at bit ``lg - 1 - i`` (most
    significant first) and stably partitions the current arrangement:
    elements whose bit is 0 come first, then those whose bit is 1.
    Per level we keep

    * ``rank1[i][j]`` -- number of 1 bits among the first ``j`` positions of
      the arrangement that level ``i`` partitions;
    * ``zeros[i]`` -- total number of 0 bits at level ``i``;
    * ``accs[i][j]`` -- prefix sums of the values in the arrangement
      produced by level ``i``.

    ``rank1`` replaces big-integer bit vectors: ``pref[j]`` is an O(1) list
    lookup, whereas masking and popcounting an n-bit integer costs
    O(n / word size) per call, and setting bits one by one costs
    O(n^2 / word size) per level. Construction is O(n * lg) time and
    memory; every query below is O(lg).

    Stored values must be non-negative integers. Value bounds passed to the
    counting / summing queries may be any integer.
    """

    def __init__(self, V):
        n = len(V)
        lg = max(V).bit_length() if V else 0
        if lg == 0:
            # Every value is 0 (or V is empty): still build a single level.
            lg = 1
        self.n = n
        self.lg = lg
        self.zeros = [0] * lg
        self.rank1 = []
        self.accs = []
        self.original_V = V
        curr = list(V)
        for i in range(lg):
            mask = 1 << (lg - 1 - i)
            pref = [0] * (n + 1)
            zero_arr = []
            one_arr = []
            z_app = zero_arr.append
            o_app = one_arr.append
            ones = 0
            for j, v in enumerate(curr):
                if v & mask:
                    ones += 1
                    o_app(v)
                else:
                    z_app(v)
                pref[j + 1] = ones
            self.rank1.append(pref)
            self.zeros[i] = len(zero_arr)
            # Stable partition: 0-bit elements first, then 1-bit elements.
            curr = zero_arr + one_arr
            acc = [0] * (n + 1)
            for j, v in enumerate(curr):
                acc[j + 1] = acc[j] + v
            self.accs.append(acc)
        acc_orig = [0] * (n + 1)
        for j, v in enumerate(V):
            acc_orig[j + 1] = acc_orig[j] + v
        self.accs_orig = acc_orig

    @property
    def B(self):
        """Per-level bit vectors packed into ints (bit ``j`` = position ``j``).

        Rebuilt on demand from ``rank1`` for backward compatibility; the
        queries themselves only use the prefix-count tables.
        """
        n = self.n
        packed = []
        for pref in self.rank1:
            bits = "".join(
                "1" if pref[j + 1] != pref[j] else "0"
                for j in range(n - 1, -1, -1)
            )
            packed.append(int(bits, 2) if bits else 0)
        return packed

    def access(self, i):
        """Return the ``i``-th element of the original sequence."""
        return self.original_V[i]

    def rank(self, r, x):
        """Return how many times ``x`` occurs in positions ``[0, r)``."""
        lg = self.lg
        if x >> lg:
            # x is negative or needs more than lg bits: it cannot occur.
            return 0
        rank1 = self.rank1
        zeros = self.zeros
        # Track both ends: the answer is the width of x's final bucket.
        l = 0
        for i in range(lg):
            if l == r:
                return 0
            pref = rank1[i]
            l1 = pref[l]
            r1 = pref[r]
            if (x >> (lg - 1 - i)) & 1:
                z = zeros[i]
                l = z + l1
                r = z + r1
            else:
                l -= l1
                r -= r1
        return r - l

    def rank_range(self, l, r, x):
        """Return how many times ``x`` occurs in positions ``[l, r)``."""
        return self.rank(r, x) - self.rank(l, x)

    def quantile(self, l, r, k):
        """Return the ``k``-th smallest value (0-indexed) in ``[l, r)``.

        ``0 <= k < r - l`` is required and not validated.
        """
        rank1 = self.rank1
        zeros = self.zeros
        lg = self.lg
        res = 0
        for i in range(lg):
            pref = rank1[i]
            l1 = pref[l]
            r1 = pref[r]
            z = (r - l) - (r1 - l1)  # 0-bit elements in the current range
            if k < z:
                l -= l1
                r -= r1
            else:
                res |= 1 << (lg - 1 - i)
                k -= z
                zc = zeros[i]
                l = zc + l1
                r = zc + r1
        return res

    def _range_freq(self, l, r, x):
        """Return how many values in positions ``[l, r)`` are ``< x``."""
        if x <= 0:
            # Stored values are non-negative. This also avoids walking the
            # bits of a negative x, whose high bits are all ones.
            return 0
        lg = self.lg
        if x.bit_length() > lg:
            # x >= 2**lg, which is larger than every stored value.
            return r - l
        rank1 = self.rank1
        zeros = self.zeros
        res = 0
        for i in range(lg):
            if l == r:
                break
            pref = rank1[i]
            l1 = pref[l]
            r1 = pref[r]
            l0 = l - l1
            r0 = r - r1
            if (x >> (lg - 1 - i)) & 1:
                # x has a 1 here: every 0-bit element in range is smaller.
                res += r0 - l0
                z = zeros[i]
                l = z + l1
                r = z + r1
            else:
                l = l0
                r = r0
        return res

    def range_freq(self, left, right, lower, upper):
        """Return how many values in ``[left, right)`` satisfy ``lower <= v < upper``."""
        return self._range_freq(left, right, upper) - self._range_freq(left, right, lower)

    def prev_value(self, left, right, upper):
        """Return the largest value ``< upper`` in ``[left, right)``, or None."""
        cnt = self._range_freq(left, right, upper)
        return self.quantile(left, right, cnt - 1) if cnt > 0 else None

    def next_value(self, left, right, lower):
        """Return the smallest value ``>= lower`` in ``[left, right)``, or None."""
        cnt = self._range_freq(left, right, lower)
        return self.quantile(left, right, cnt) if cnt < right - left else None

    def _range_sum(self, l, r, x):
        """Return the sum of values in positions ``[l, r)`` that are ``< x``."""
        if x <= 0:
            # Same reasoning as in _range_freq: nothing stored is below 0.
            return 0
        lg = self.lg
        if x.bit_length() > lg:
            return self.accs_orig[r] - self.accs_orig[l]
        rank1 = self.rank1
        zeros = self.zeros
        accs = self.accs
        res = 0
        for i in range(lg):
            if l == r:
                break
            pref = rank1[i]
            l1 = pref[l]
            r1 = pref[r]
            l0 = l - l1
            r0 = r - r1
            if (x >> (lg - 1 - i)) & 1:
                # The 0-bit elements of the range occupy [l0, r0) in the
                # arrangement produced by this level; all are < x.
                res += accs[i][r0] - accs[i][l0]
                z = zeros[i]
                l = z + l1
                r = z + r1
            else:
                l = l0
                r = r0
        return res

    def range_sum(self, left, right, lower, upper):
        """Return the sum of values in ``[left, right)`` with ``lower <= v < upper``."""
        return self._range_sum(left, right, upper) - self._range_sum(left, right, lower)

    def _build_distinct_wm(self):
        """Build the helper matrix used by ``range_distinct``.

        ``P[i]`` is 1 + the index of the previous occurrence of ``V[i]``, or 0
        when ``V[i]`` has not appeared before.
        """
        P = [0] * self.n
        last_pos = {}
        for i, v in enumerate(self.original_V):
            P[i] = last_pos.get(v, -1) + 1
            last_pos[v] = i
        self._distinct_wm = WaveletMatrix(P)

    def range_distinct(self, left, right):
        """Return the number of distinct values in ``[left, right)``."""
        if not hasattr(self, "_distinct_wm"):
            self._build_distinct_wm()
        # Position i is the first occurrence of its value inside
        # [left, right) exactly when its previous occurrence is before
        # left, i.e. P[i] <= left.
        return self._distinct_wm.range_freq(left, right, 0, left + 1)

    def bottom_k_sum(self, l, r, k):
        """Return the sum of the ``k`` smallest values in ``[l, r)``.

        ``k <= 0`` gives 0; ``k >= r - l`` gives the sum of the whole range.
        """
        if k <= 0:
            return 0
        if k >= r - l:
            return self.accs_orig[r] - self.accs_orig[l]
        rank1 = self.rank1
        zeros = self.zeros
        accs = self.accs
        lg = self.lg
        res = 0
        val = 0
        for i in range(lg):
            pref = rank1[i]
            l1 = pref[l]
            r1 = pref[r]
            l0 = l - l1
            r0 = r - r1
            z = r0 - l0
            if k <= z:
                # All k smallest values have a 0 at this bit.
                l = l0
                r = r0
            else:
                # Take every 0-bit element, continue among the 1-bit ones.
                res += accs[i][r0] - accs[i][l0]
                k -= z
                val |= 1 << (lg - 1 - i)
                zc = zeros[i]
                l = zc + l1
                r = zc + r1
        # The remaining range holds copies of `val`; take k of them.
        return res + k * val

    def top_k_sum(self, l, r, k):
        """Return the sum of the ``k`` largest values in ``[l, r)``.

        ``k <= 0`` gives 0; ``k >= r - l`` gives the sum of the whole range.
        """
        if k <= 0:
            return 0
        total = self.accs_orig[r] - self.accs_orig[l]
        length = r - l
        if k >= length:
            return total
        return total - self.bottom_k_sum(l, r, length - k)
MOD = 998244353


def main():
    """Solve yukicoder No.1300 "Sum of Inversions" from stdin.

    Fixing the middle index i of an inverted triple (p < i < q with
    A[p] > A[i] > A[q]), the triples through i contribute

        left_sum * right_cnt + A[i] * left_cnt * right_cnt + left_cnt * right_sum

    where left_* describe the values > A[i] before i and right_* the values
    < A[i] after i. The answer is printed modulo MOD.
    """
    N = int(input())
    A = list(map(int, input().split()))
    WM = WaveletMatrix(A)
    prefix = [0] * (N + 1)
    for i, a in enumerate(A):
        prefix[i + 1] = prefix[i] + a
    ans = 0
    # Indices 0 and N-1 cannot be the middle of a triple.
    for i in range(1, N - 1):
        a = A[i]
        # Left side, values > a: everything before i minus the values <= a.
        left_cnt = i - WM.range_freq(0, i, 0, a + 1)
        left_sum = prefix[i] - WM.range_sum(0, i, 0, a + 1)
        # Right side, values < a.
        right_cnt = WM.range_freq(i + 1, N, 0, a)
        right_sum = WM.range_sum(i + 1, N, 0, a)
        # Reduce factors first so every product stays a machine-size int.
        term = (
            (left_sum % MOD) * right_cnt
            + (a * left_cnt % MOD) * right_cnt
            + left_cnt * (right_sum % MOD)
        )
        ans = (ans + term) % MOD
    print(ans)


if __name__ == "__main__":
    main()
回転