結果
問題 | No.1300 Sum of Inversions |
ユーザー | shotoyoo |
提出日時 | 2020-11-27 22:40:53 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 540 ms / 2,000 ms |
コード長 | 2,340 bytes |
コンパイル時間 | 459 ms |
コンパイル使用メモリ | 82,640 KB |
実行使用メモリ | 145,524 KB |
最終ジャッジ日時 | 2024-07-26 19:02:12 |
合計ジャッジ時間 | 15,151 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 34 |
ソースコード
import sys input = lambda : sys.stdin.readline().rstrip() sys.setrecursionlimit(max(1000, 10**9)) write = lambda x: sys.stdout.write(x+"\n") class BIT: ### BIT binary def __init__(self, n, values=None): self.bit = [0]*(n+1) self.n = n if values is not None: for i,v in enumerate(values): self.add(i,v) #a1 ~ aiまでの和 O(logn) def query(self,i): res = 0 while i > 0: res += self.bit[i] i -= i&(-i) return res #ai += x(logN) def add(self,i,x): i += 1 if i==0: raise RuntimeError while i <= len(self.bit)-1: self.bit[i] += x i += i&(-i) def check(self): l = [] prv = 0 for i in range(1,self.n+1): val = self.query(i) l.append(val-prv) prv = val print(" ".join(map(str, l))) def index(self, v): """a0,...,aiの和がv以上になる最小のindexを求める 存在しないとき配列サイズを返す """ if v <= 0: return 0 x = 0 r = 1 while r < n: r = r << 1; ll = r while ll>0: if x+ll<n and self.bit[x+ll]<v: v -= self.bit[x+ll] x += ll ll = ll>>1 return x from bisect import bisect_left def press(l): # xs[inds[i]]==l[i]となる xs = sorted(set(l)) inds = [None] * len(l) for i,item in enumerate(l): inds[i] = bisect_left(xs, item) return xs, inds M = 998244353 n = int(input()) a = list(map(int, input().split())) xs, inds = press(a) m = len(xs) bit = BIT(m) v0 = [] for i,v in enumerate(inds): v0.append(i - bit.query(v+1)) bit.add(v, 1) bit = BIT(m) v1 = [] for i,v in enumerate(inds[::-1]): v1.append(bit.query(v)) bit.add(v, 1) v1 = v1[::-1] bit = BIT(m) c1 = [] for i in range(n-1,-1,-1): v = inds[i] c1.append(bit.query(v)) bit.add(v, v1[i]) c1 = c1[::-1] bit = BIT(m) c0 = [] tmp = 0 for i in range(n): v = inds[i] c0.append(tmp-bit.query(v+1)) bit.add(v, v0[i]) tmp += v0[i] ans = 0 for i in range(n): ans += a[i]*v0[i]*v1[i] # print(ans) # for i in range(n): ans += a[i]*(c0[i] + c1[i]) ans %= M print(ans)