結果

| 問題 | No.877 Range ReLU Query |
|---|---|
| コンテスト | |
| ユーザー | |
| 提出日時 | 2019-09-06 22:48:58 |
| 言語 | Java (openjdk 23) |
| 結果 | AC |
| 実行時間 | 959 ms / 2,000 ms |
| コード長 | 7,909 bytes |
| コンパイル時間 | 5,635 ms |
| コンパイル使用メモリ | 81,792 KB |
| 実行使用メモリ | 62,748 KB |
| 最終ジャッジ日時 | 2024-11-08 10:09:53 |
| 合計ジャッジ時間 | 15,746 ms |
| ジャッジサーバーID (参考情報) | judge1 / judge4 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 20 |
ソースコード
import java.io.*;
import java.util.*;
public class Main_yukicoder877 {
    private static Scanner sc;
    private static Printer pr;

    /**
     * Answers "Range ReLU" queries: for each query {@code (t, l, r, x)} (1-indexed, inclusive
     * {@code [l, r]}; {@code t} is read but not used) prints the sum of {@code max(0, a_i - x)} over
     * {@code i} in {@code [l, r]}, one answer per line in input order.
     *
     * <p>Queries are processed offline in decreasing order of {@code x}. Before answering a query,
     * every array element with {@code a_i >= x} is inserted into two segment trees: one holding the
     * values, one holding counts. The answer is then {@code sum - count * x} over the range.
     * Elements with {@code a_i < x} are never counted, which is exactly the ReLU clamp.
     */
    private static void solve() {
        int n = sc.nextInt();
        int q = sc.nextInt();
        int[] a = sc.nextIntArray(n);

        // Max-heap of (value, index): elements are released in decreasing order of value.
        PriorityQueue<Pair> pq = new PriorityQueue<>(Collections.reverseOrder());
        for (int i = 0; i < n; i++) {
            pq.add(new Pair(a[i], i));
        }

        // qlrx[0] = query type (unused), qlrx[1] = l, qlrx[2] = r, qlrx[3] = x.
        int[][] qlrx = sc.nextIntArrays(4, q);
        List<Pair> queries = new ArrayList<>(q);
        for (int i = 0; i < q; i++) {
            queries.add(new Pair(qlrx[3][i], i));
        }
        Collections.sort(queries);

        long[] ans = new long[q];
        RSQ sum = new RSQ(n);   // sum of the values inserted so far, by position
        RSQ count = new RSQ(n); // how many values have been inserted so far, by position
        for (int i = q - 1; i >= 0; i--) {
            Pair query = queries.get(i);
            int x = query.a;
            // Thresholds only decrease from here on, so anything inserted now stays relevant.
            while (!pq.isEmpty() && pq.peek().a >= x) {
                Pair e = pq.remove();
                sum.add(e.b, e.a);
                count.add(e.b, 1);
            }
            int l = qlrx[1][query.b] - 1;
            int r = qlrx[2][query.b] - 1;
            // Widen to long before multiplying: count * x can exceed the int range.
            ans[query.b] = sum.query(l, r + 1) - count.query(l, r + 1) * (long) x;
        }
        for (long e : ans) {
            pr.println(e);
        }
    }

    /**
     * Range Sum Query backed by an array-based segment tree.
     */
    static class RSQ {
        long[] st;
        int n;

        /**
         * Creates a segment tree able to hold {@code n} elements, all initially zero.
         *
         * @param n number of elements
         */
        RSQ(int n) {
            // Round the leaf count up to a power of two so the tree is complete.
            this.n = 1;
            while (this.n < n) {
                this.n *= 2;
            }
            st = new long[2 * this.n - 1];
        }

        /**
         * Adds {@code x} to the {@code i}-th element.
         *
         * @param i target element (0-indexed)
         * @param x value to add
         */
        void add(int i, int x) {
            // Leaves occupy st[n-1 .. 2n-2]; the parent of node k is (k-1)/2.
            i = n - 1 + i;
            st[i] += x;
            while (i > 0) {
                i = (i - 1) / 2;
                st[i] += x;
            }
        }

        /**
         * Returns the sum of the elements in the half-open range {@code [a, b)}.
         *
         * @param a inclusive start of the range (0-indexed)
         * @param b exclusive end of the range (0-indexed)
         * @return the sum of the elements in {@code [a, b)}
         */
        long query(int a, int b) {
            return query(a, b, 0, 0, n);
        }

        /** Sum over {@code [a, b)} restricted to node {@code i}, which covers {@code [l, r)}. */
        private long query(int a, int b, int i, int l, int r) {
            if (a >= r || b <= l) {
                return 0;
            }
            if (a <= l && b >= r) {
                return st[i];
            }
            return query(a, b, i * 2 + 1, l, (l + r) / 2) + query(a, b, i * 2 + 2, (l + r) / 2, r);
        }
    }

    /** Pair of ints ordered by {@code a}, then by {@code b}; equality and hash use both fields. */
    static class Pair implements Comparable<Pair> {
        int a;
        int b;

        Pair(int a, int b) {
            this.a = a;
            this.b = b;
        }

        @Override
        public int compareTo(Pair o) {
            int result = Integer.compare(a, o.a);
            if (result == 0) {
                result = Integer.compare(b, o.b);
            }
            return result;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = a;
            result = prime * result + b;
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Pair)) {
                return false;
            }
            Pair o = (Pair) obj;
            return a == o.a && b == o.b;
        }

        @Override
        public String toString() {
            // Format: "Pair [a, b]"
            StringBuilder stmp = new StringBuilder(32);
            stmp.append("Pair [");
            stmp.append(a);
            stmp.append(',');
            stmp.append(' ');
            stmp.append(b);
            stmp.append(']');
            return stmp.toString();
        }
    }

    // ---------------------------------------------------
    /** Entry point: reads from standard input, writes answers to standard output. */
    public static void main(String[] args) {
        sc = new Scanner(System.in);
        pr = new Printer(System.out);
        solve();
        pr.close();
        sc.close();
    }

    /**
     * Token reader over an {@link InputStream}. Tokens are runs of visible ASCII characters
     * ('!'..'~'); read failures and end of input are reported as {@link NoSuchElementException}.
     */
    static class Scanner {
        BufferedReader br;

        Scanner(InputStream in) {
            br = new BufferedReader(new InputStreamReader(in));
        }

        /** Visible ASCII characters; anything else separates tokens. */
        private boolean isPrintable(int ch) {
            return ch >= '!' && ch <= '~';
        }

        /** Line terminators; end of stream (-1) also terminates a line. */
        private boolean isCRLF(int ch) {
            return ch == '\n' || ch == '\r' || ch == -1;
        }

        /** Converts an I/O failure into this reader's unchecked exception, keeping the cause. */
        private static NoSuchElementException ioFailure(IOException e) {
            NoSuchElementException ex = new NoSuchElementException(e.getMessage());
            ex.initCause(e);
            return ex;
        }

        /** Skips separators and returns the first visible character; throws at end of input. */
        private int nextPrintable() {
            try {
                int ch;
                while (!isPrintable(ch = br.read())) {
                    if (ch == -1) {
                        throw new NoSuchElementException();
                    }
                }
                return ch;
            } catch (IOException e) {
                throw ioFailure(e);
            }
        }

        /** Returns the next token. */
        String next() {
            try {
                int ch = nextPrintable();
                StringBuilder sb = new StringBuilder();
                do {
                    sb.appendCodePoint(ch);
                } while (isPrintable(ch = br.read()));
                return sb.toString();
            } catch (IOException e) {
                throw ioFailure(e);
            }
        }

        /**
         * Parses the next token as a decimal int with an optional sign.
         *
         * @throws NumberFormatException if the token is not a valid int
         */
        int nextInt() {
            try {
                // Mirrors Integer.parseInt: accumulate as a negative number so that
                // Integer.MIN_VALUE parses without overflow.
                boolean negative = false;
                int res = 0;
                int limit = -Integer.MAX_VALUE;
                int radix = 10;
                int fc = nextPrintable();
                if (fc < '0') {
                    if (fc == '-') {
                        negative = true;
                        limit = Integer.MIN_VALUE;
                    } else if (fc != '+') {
                        throw new NumberFormatException();
                    }
                    fc = br.read();
                }
                int multmin = limit / radix;
                int ch = fc;
                do {
                    int digit = ch - '0';
                    if (digit < 0 || digit >= radix) {
                        throw new NumberFormatException();
                    }
                    if (res < multmin) {
                        throw new NumberFormatException();
                    }
                    res *= radix;
                    if (res < limit + digit) {
                        throw new NumberFormatException();
                    }
                    res -= digit;
                } while (isPrintable(ch = br.read()));
                return negative ? res : -res;
            } catch (IOException e) {
                throw ioFailure(e);
            }
        }

        /**
         * Parses the next token as a decimal long with an optional sign.
         *
         * @throws NumberFormatException if the token is not a valid long
         */
        long nextLong() {
            try {
                // Mirrors Long.parseLong: negative accumulation so Long.MIN_VALUE parses.
                boolean negative = false;
                long res = 0;
                long limit = -Long.MAX_VALUE;
                int radix = 10;
                int fc = nextPrintable();
                if (fc < '0') {
                    if (fc == '-') {
                        negative = true;
                        limit = Long.MIN_VALUE;
                    } else if (fc != '+') {
                        throw new NumberFormatException();
                    }
                    fc = br.read();
                }
                long multmin = limit / radix;
                int ch = fc;
                do {
                    int digit = ch - '0';
                    if (digit < 0 || digit >= radix) {
                        throw new NumberFormatException();
                    }
                    if (res < multmin) {
                        throw new NumberFormatException();
                    }
                    res *= radix;
                    if (res < limit + digit) {
                        throw new NumberFormatException();
                    }
                    res -= digit;
                } while (isPrintable(ch = br.read()));
                return negative ? res : -res;
            } catch (IOException e) {
                throw ioFailure(e);
            }
        }

        float nextFloat() {
            return Float.parseFloat(next());
        }

        double nextDouble() {
            return Double.parseDouble(next());
        }

        /** Returns the next non-empty line (empty lines are skipped), without its terminator. */
        String nextLine() {
            try {
                int ch;
                while (isCRLF(ch = br.read())) {
                    if (ch == -1) {
                        throw new NoSuchElementException();
                    }
                }
                StringBuilder sb = new StringBuilder();
                do {
                    sb.appendCodePoint(ch);
                } while (!isCRLF(ch = br.read()));
                return sb.toString();
            } catch (IOException e) {
                throw ioFailure(e);
            }
        }

        /** Reads {@code n} ints from this reader. */
        int[] nextIntArray(int n) {
            int[] ret = new int[n];
            for (int i = 0; i < n; i++) {
                // Read from this instance, not the global reader.
                ret[i] = nextInt();
            }
            return ret;
        }

        /** Reads {@code n} longs from this reader. */
        long[] nextLongArray(int n) {
            long[] ret = new long[n];
            for (int i = 0; i < n; i++) {
                ret[i] = nextLong();
            }
            return ret;
        }

        /**
         * Reads {@code n} records of {@code m} ints each and returns them column-wise:
         * {@code ret[j][i]} is the {@code j}-th value of the {@code i}-th record.
         */
        int[][] nextIntArrays(int m, int n) {
            int[][] ret = new int[m][n];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    ret[j][i] = nextInt();
                }
            }
            return ret;
        }

        void close() {
            try {
                br.close();
            } catch (IOException ignored) {
                // Best-effort: a failure to close the input at shutdown has no useful recovery.
            }
        }
    }

    /** {@link PrintWriter} with helpers that print space-separated values on one line. */
    static class Printer extends PrintWriter {
        Printer(OutputStream out) {
            super(out);
        }

        void printInts(int... a) {
            StringBuilder sb = new StringBuilder(32);
            for (int i = 0, size = a.length; i < size; i++) {
                if (i > 0) {
                    sb.append(' ');
                }
                sb.append(a[i]);
            }
            println(sb);
        }

        void printLongs(long... a) {
            StringBuilder sb = new StringBuilder(64);
            for (int i = 0, size = a.length; i < size; i++) {
                if (i > 0) {
                    sb.append(' ');
                }
                sb.append(a[i]);
            }
            println(sb);
        }

        void printStrings(String... a) {
            StringBuilder sb = new StringBuilder(32);
            for (int i = 0, size = a.length; i < size; i++) {
                if (i > 0) {
                    sb.append(' ');
                }
                sb.append(a[i]);
            }
            println(sb);
        }
    }
}