結果

| 問題 | No.2327 Inversion Sum |
|---|---|
| コンテスト | |
| ユーザー | sepa38 |
| 提出日時 | 2023-05-09 19:39:56 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 298 ms / 2,000 ms |
| コード長 | 2,656 bytes |
| コンパイル時間 | 169 ms |
| コンパイル使用メモリ | 82,096 KB |
| 実行使用メモリ | 91,064 KB |
| 最終ジャッジ日時 | 2024-11-26 07:38:38 |
| 合計ジャッジ時間 | 5,719 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge1 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 30 |
ソースコード
class segt:
    """Iterative segment tree over a monoid ``(calc, ele)``.

    Leaf ``i`` is stored at ``data[num - 1 + i]``; internal node ``u``
    (0-indexed) has children ``2*u + 1`` and ``2*u + 2``. ``calc`` must be
    associative with ``ele`` as its identity. It does NOT have to be
    commutative: ``prod`` folds the range strictly left to right.
    """

    def __init__(self, n, ele, calc):
        """Create a tree with room for ``n`` leaves, all set to ``ele``."""
        self.num = 2 ** (n - 1).bit_length()  # leaf capacity, a power of two
        self.el = ele
        self.data = [ele] * (2 * self.num)
        self.calc = calc

    def update(self, idx, x):
        """Set leaf ``idx`` to ``x`` and recompute all of its ancestors."""
        idx += self.num - 1
        self.data[idx] = x
        while idx > 0:
            idx = (idx - 1) // 2
            self.data[idx] = self.calc(self.data[2 * idx + 1], self.data[2 * idx + 2])

    def renew(self, idx, x):
        """Replace leaf ``idx`` with ``calc(current_value, x)``."""
        self.update(idx, self.calc(self.get(idx), x))

    def prod(self, left, right):
        """Return ``calc`` folded over leaves ``left .. right-1`` in order.

        An empty range (``left >= right``) yields ``ele``.
        """
        # l and r are 1-indexed node ids here; node v lives at data[v - 1].
        l = left + self.num
        r = right + self.num
        res_l = self.el
        res_r = self.el
        while l < r:
            if l & 1:
                res_l = self.calc(res_l, self.data[l - 1])
                l += 1
            if r & 1:
                r -= 1
                # Nodes taken from the right edge are *prepended* so that a
                # non-commutative calc still sees the leaves left to right.
                res_r = self.calc(self.data[r - 1], res_r)
            l >>= 1
            r >>= 1
        return self.calc(res_l, res_r)

    def get(self, idx):
        """Return the value stored at leaf ``idx``."""
        return self.data[idx + self.num - 1]
def add(x, y):
    """Return ``x + y``; used as the combining function of sum segment trees."""
    combined = x + y
    return combined
# First line: n and m. Next m lines: a "p k" pair each, converted to 0-indexed.
n, m = map(int, input().split())
pk = []
for _ in range(m):
    pk.append([int(tok) - 1 for tok in input().split()])
def solve(n, m, pk):
    """Return the sum of inversion counts over all valid completions, mod 998244353.

    Each ``[p, k]`` in ``pk`` pins row ``p`` to column ``k``. A permutation
    and its inverse have the same inversion count, so it does not matter
    which coordinate is treated as the "row". The ``n - m`` unpinned rows
    take the ``n - m`` unused columns in every possible way. Rows are
    scanned left to right, and every pair (earlier row, current row) is
    counted by case: pinned/pinned, free/pinned, pinned/free, free/free.

    Args:
        n: permutation length.
        m: number of constraints (``len(pk)``).
        pk: ``[p, k]`` pairs, both 0-indexed, with distinct ``p`` values
            and distinct ``k`` values.

    Returns:
        Total number of inversions over all completions, modulo 998244353.
    """
    mod = 998244353
    free = n - m

    fac = [1] * (n + 1)
    for i in range(1, n + 1):
        fac[i] = fac[i - 1] * i % mod
    # Completions in which one given free row receives one given free column.
    # Guarded explicitly instead of relying on fac[-1] being multiplied by 0.
    one_pinned = fac[free - 1] if free >= 1 else 0
    # Completions in which two given free rows form an inversion:
    # C(free, 2) inverted column pairs times (free - 2)! for the others.
    pair_inverted = free * (free - 1) // 2 * fac[free - 2] % mod if free >= 2 else 0

    col = [-1] * n  # col[p] = pinned column of row p, or -1 if row p is free
    is_fixed_col = [0] * n
    for p, k in pk:
        col[p] = k
        is_fixed_col[k] = 1
    # fixed_below[c] = number of pinned columns strictly less than c.
    fixed_below = [0] * (n + 1)
    for c in range(n):
        fixed_below[c + 1] = fixed_below[c] + is_fixed_col[c]
    total_fixed = fixed_below[n]

    tree = [0] * (n + 1)  # Fenwick tree over columns of pinned rows seen so far
    fixed_seen = 0
    free_seen = 0
    free_below_sum = 0  # sum over pinned rows seen of (#free columns < their column)
    ans = 0
    for c in col:
        if c == -1:
            # (pinned earlier row j, this free row) is inverted when this row
            # takes one of the free columns below col[j];
            # (free earlier row, this free row) contributes pair_inverted each.
            ans += free_below_sum * one_pinned + free_seen * pair_inverted
            free_seen += 1
        else:
            # Pinned earlier rows whose column exceeds c.
            j = c + 1
            at_most_c = 0
            while j > 0:
                at_most_c += tree[j]
                j &= j - 1
            greater = fixed_seen - at_most_c
            # Free columns above c; each earlier free row taking one inverts.
            free_above = (n - c) - (total_fixed - fixed_below[c])
            ans += greater * fac[free] + free_seen * free_above * one_pinned
            free_below_sum += c - fixed_below[c]
            j = c + 1
            while j <= n:
                tree[j] += 1
                j += j & -j
            fixed_seen += 1
        ans %= mod
    return ans
from itertools import permutations
def naive(n, m, pk):
    """Brute-force reference for ``solve``; exponential in ``n``.

    Enumerates every permutation ``perm`` of ``range(n)``, keeps those with
    ``perm[k] == p`` for every ``(p, k)`` in ``pk``, and returns the total
    number of inversions among them (no modulus). ``m`` is unused.
    """
    fixed = [-1] * n
    for value, pos in pk:
        fixed[pos] = value
    total = 0
    for perm in permutations(range(n)):
        if any(fixed[i] != -1 and perm[i] != fixed[i] for i in range(n)):
            continue
        total += sum(perm[a] > perm[b] for a in range(n) for b in range(a + 1, n))
    return total
def _stress_test(trials=100, max_n=7):
    """Cross-check ``solve`` against ``naive`` on random small instances.

    Each trial draws ``n`` in ``[1, max_n]`` and makes up to ``2 * n``
    attempts at a random 0-indexed ``(p, k)`` pair. A pair is kept only when
    neither ``p`` nor ``k`` has been used yet. The first mismatching case is
    printed in the 1-indexed input format and returned as ``(n, m, pk)``.
    ``None`` is returned when every trial agrees.

    ``max_n`` defaults to 7 because ``naive`` enumerates all ``n!``
    permutations. This helper is not called by the submission; run it by
    hand when debugging.
    """
    import random

    for _ in range(trials):
        size = random.randint(1, max_n)
        used_p = [False] * size
        used_k = [False] * size
        pairs = []
        for _ in range(random.randint(0, size * 2)):
            p, k = random.randint(0, size - 1), random.randint(0, size - 1)
            if not used_p[p] and not used_k[k]:
                pairs.append([p, k])
                used_p[p] = used_k[k] = True
        cnt = len(pairs)
        if solve(size, cnt, pairs) != naive(size, cnt, pairs):
            print(size, cnt)
            for p, k in pairs:
                print(p + 1, k + 1)
            return size, cnt, pairs
    return None


# Submission entry point: print the answer for the parsed input.
print(solve(n, m, pk))
sepa38