結果

問題 No.925 紲星 Extra
ユーザー akakimidoriakakimidori
提出日時 2020-02-19 20:41:56
言語 Rust (1.77.0)
結果
AC  
実行時間 3,927 ms / 10,000 ms
コード長 8,143 bytes
コンパイル時間 1,117 ms
コンパイル使用メモリ 178,232 KB
実行使用メモリ 76,288 KB
最終ジャッジ日時 2024-04-17 06:07:43
合計ジャッジ時間 55,317 ms
ジャッジサーバーID (参考情報) judge4 / judge1
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 1 ms 5,248 KB
testcase_01 AC 1 ms 5,376 KB
testcase_02 AC 1 ms 5,376 KB
testcase_03 AC 11 ms 5,376 KB
testcase_04 AC 11 ms 5,376 KB
testcase_05 AC 3,049 ms 75,008 KB
testcase_06 AC 3,125 ms 75,776 KB
testcase_07 AC 3,109 ms 75,264 KB
testcase_08 AC 3,353 ms 75,776 KB
testcase_09 AC 3,137 ms 75,520 KB
testcase_10 AC 3,412 ms 74,752 KB
testcase_11 AC 3,027 ms 76,032 KB
testcase_12 AC 2,823 ms 75,520 KB
testcase_13 AC 2,501 ms 75,264 KB
testcase_14 AC 3,174 ms 74,752 KB
testcase_15 AC 2,263 ms 75,648 KB
testcase_16 AC 3,608 ms 75,392 KB
testcase_17 AC 3,732 ms 74,496 KB
testcase_18 AC 3,323 ms 76,032 KB
testcase_19 AC 3,927 ms 75,776 KB
testcase_20 AC 2,736 ms 76,160 KB
testcase_21 AC 1,690 ms 76,288 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

/// Owning, nullable pointer to a subtree; `None` is the empty tree.
type Link = Option<Box<Node>>;

/// One node of a height-balanced binary search tree holding a multiset of
/// `u64` values, with aggregates of its whole subtree cached on the node.
struct Node {
    /// Value stored at this node.
    val: u64,
    /// Subtree holding this node's in-order predecessors.
    left: Link,
    /// Subtree holding this node's in-order successors.
    right: Link,
    /// Cached number of nodes in the subtree rooted here.
    size: usize,
    /// Cached sum of every value in the subtree rooted here.
    sum: u64,
    /// Cached height of the subtree rooted here (a lone leaf has rank 1).
    rank: usize,
}

fn get_size(node: &Link) -> usize {
    node.as_ref().map_or(0, |t| t.get_size())
}

fn get_sum(node: &Link) -> u64 {
    node.as_ref().map_or(0, |t| t.get_sum())
}

fn get_rank(node: &Link) -> usize {
    node.as_ref().map_or(0, |t| t.get_rank())
}

fn insert(node: &mut Link, val: u64) {
    if node.is_none() {
        *node = Some(Box::new(Node {
            val: val,
            sum: val,
            size: 1,
            rank: 1,
            left: None,
            right: None,
        }));
        return;
    }
    let node = node.as_mut().unwrap();
    if node.val <= val {
        insert(&mut node.right, val);
    } else {
        insert(&mut node.left, val);
    }
    node.update();
    node.balance();
}

fn merge(l: Link, r: Link) -> Link {
    if l.is_none() {
        return r;
    }
    if r.is_none() {
        return l;
    }
    let l_size = get_size(&l);
    let r_size = get_size(&r);
    if l_size >= r_size {
        let mut l = l.unwrap();
        l.right = merge(l.right.take(), r);
        l.update();
        l.balance();
        Some(l)
    } else {
        let mut r = r.unwrap();
        r.left = merge(l, r.left.take());
        r.update();
        r.balance();
        Some(r)
    }
}

/// Removes one node holding `val` from the tree rooted at `node`.
///
/// The matching node is spliced out by merging its two subtrees. Every
/// ancestor on the search path is then updated and rebalanced.
///
/// # Panics
/// Panics via the `assert!` if `val` is not present in the tree.
fn remove(node: &mut Link, val: u64) {
    assert!(node.is_some());
    let cur = node.as_mut().unwrap();
    if cur.val == val {
        let left = cur.left.take();
        let right = cur.right.take();
        *node = merge(left, right);
        return;
    }
    let child = if val < cur.val {
        &mut cur.left
    } else {
        &mut cur.right
    };
    remove(child, val);
    cur.update();
    cur.balance();
}

/// Counts and sums the values `x` in the tree with `x >= val`.
///
/// Returns `(count, sum)`. This walks a single root-to-leaf path. Whenever a
/// node qualifies, its whole right subtree qualifies too and is added from
/// the cached aggregates.
fn eval(node: &Link, val: u64) -> (usize, u64) {
    let mut count = 0usize;
    let mut total = 0u64;
    let mut cur = node;
    while let Some(t) = cur {
        if t.val >= val {
            count += get_size(&t.right) + 1;
            total += get_sum(&t.right) + t.val;
            cur = &t.left;
        } else {
            cur = &t.right;
        }
    }
    (count, total)
}

#[allow(dead_code)]
fn walk(node: &Link) {
    if node.is_none() {
        return;
    }
    let node = node.as_ref().unwrap();
    walk(&node.left);
    print!("{} ", node.val);
    walk(&node.right);
}

impl Node {
    /// Cached node count of this subtree.
    fn get_size(&self) -> usize {
        self.size
    }

    /// Cached value sum of this subtree.
    fn get_sum(&self) -> u64 {
        self.sum
    }

    /// Cached height of this subtree.
    fn get_rank(&self) -> usize {
        self.rank
    }

    /// Recomputes `size`, `sum` and `rank` from this node and its children's
    /// cached values. The children must already be up to date.
    fn update(&mut self) {
        let (left_size, right_size) = (get_size(&self.left), get_size(&self.right));
        let (left_sum, right_sum) = (get_sum(&self.left), get_sum(&self.right));
        let (left_rank, right_rank) = (get_rank(&self.left), get_rank(&self.right));
        self.size = 1 + left_size + right_size;
        self.sum = self.val + left_sum + right_sum;
        self.rank = left_rank.max(right_rank) + 1;
    }

    /// Left rotation: the right child becomes this subtree's root and the old
    /// root becomes its left child.
    ///
    /// `self` is a place owned by the parent, so the node contents are swapped
    /// instead of re-pointing the parent.
    ///
    /// # Panics
    /// Panics if there is no right child.
    fn rotate_left(&mut self) {
        let mut pivot = self.right.take().unwrap();
        self.right = pivot.left.take();
        self.update();
        // After the swap `self` holds the pivot's data and `pivot` holds the
        // old root, which is hung as the new left child.
        std::mem::swap(self, &mut *pivot);
        self.left = Some(pivot);
        self.update();
    }

    /// Right rotation: mirror image of `rotate_left`.
    ///
    /// # Panics
    /// Panics if there is no left child.
    fn rotate_right(&mut self) {
        let mut pivot = self.left.take().unwrap();
        self.left = pivot.right.take();
        self.update();
        std::mem::swap(self, &mut *pivot);
        self.right = Some(pivot);
        self.update();
    }

    /// Compares the children's heights with a tolerance of `slack`.
    ///
    /// Returns `1` if the left side is taller by more than `slack`, `-1` if
    /// the right side is taller by more than `slack`, and `0` otherwise.
    fn get_bias(&self, slack: usize) -> i32 {
        let left_rank = get_rank(&self.left);
        let right_rank = get_rank(&self.right);
        if left_rank > right_rank + slack {
            1
        } else if right_rank > left_rank + slack {
            -1
        } else {
            0
        }
    }

    /// Restores balance at this node when its children's heights differ by
    /// more than one.
    ///
    /// Uses a single rotation, or a double rotation when the heavy child leans
    /// the other way.
    fn balance(&mut self) {
        match self.get_bias(1) {
            0 => {}
            1 => {
                let heavy = self.left.as_mut().unwrap();
                if heavy.get_bias(0) == -1 {
                    heavy.rotate_left();
                }
                self.rotate_right();
            }
            _ => {
                let heavy = self.right.as_mut().unwrap();
                if heavy.get_bias(0) == 1 {
                    heavy.rotate_right();
                }
                self.rotate_left();
            }
        }
    }
}

use std::io::Write;
use std::io::Read;

/// Reads the whole input from stdin and answers the queries online, writing
/// one line per type-2 query to stdout.
///
/// Every query operand is XOR-ed with the previous answer (`xor`). Indices
/// are then masked to 16 bits and values to 40 bits.
/// * Type 1 `x y`: set the 1-based position `x` to `y`.
/// * Type 2 `l r`: for positions `l..=r` (swapped if `l > r`), print the sum
///   of `|v - m|` over that range. Here `m` is the `ceil(len / 2)`-th largest
///   value in the range.
///
/// Layout: an iterative segment tree with `size = n.next_power_of_two()`
/// leaves. Leaf `i` is node `i + size` and the root is node 1. `trees[k]`
/// holds every value under node `k` in a balanced BST, and `sums[k]` holds
/// their plain sum.
fn run() {
    /// Calls `visit(k)` for each segment-tree node `k` in the canonical cover
    /// of the half-open leaf range `[lo, hi)`, walking bottom-up.
    fn for_each_cover<F: FnMut(usize)>(lo: usize, hi: usize, size: usize, mut visit: F) {
        let (mut a, mut b) = (lo + size, hi + size);
        while a < b {
            if a & 1 == 1 {
                visit(a);
                a += 1;
            }
            if b & 1 == 1 {
                b -= 1;
                visit(b);
            }
            a >>= 1;
            b >>= 1;
        }
    }

    const INDEX_MASK: usize = (1 << 16) - 1;
    const VALUE_MASK: u64 = (1 << 40) - 1;

    let mut input = String::new();
    std::io::stdin().read_to_string(&mut input).unwrap();
    let mut tokens = input.trim().split_whitespace();
    let n: usize = tokens.next().unwrap().parse().unwrap();
    let q: usize = tokens.next().unwrap().parse().unwrap();
    let mut values: Vec<u64> = (0..n)
        .map(|_| tokens.next().unwrap().parse().unwrap())
        .collect();

    let size = n.next_power_of_two();
    // Each value is inserted into every ancestor of its leaf, root included.
    let mut trees: Vec<_> = (0..(2 * size)).map(|_| None).collect();
    for (pos, &v) in values.iter().enumerate() {
        let mut k = pos + size;
        while k > 0 {
            insert(&mut trees[k], v);
            k >>= 1;
        }
    }
    let mut sums = vec![0u64; 2 * size];
    sums[size..size + n].copy_from_slice(&values);
    for k in (1..size).rev() {
        sums[k] = sums[2 * k] + sums[2 * k + 1];
    }

    let stdout = std::io::stdout();
    let mut out = std::io::BufWriter::new(stdout.lock());
    let mut xor = 0u64;
    for _ in 0..q {
        let op: usize = tokens.next().unwrap().parse().unwrap();
        if op == 1 {
            let raw_x: usize = tokens.next().unwrap().parse().unwrap();
            let pos = ((raw_x ^ (xor as usize)) & INDEX_MASK) - 1;
            let raw_y: u64 = tokens.next().unwrap().parse().unwrap();
            let new_val = (raw_y ^ xor) & VALUE_MASK;
            let old_val = std::mem::replace(&mut values[pos], new_val);

            let mut k = pos + size;
            while k > 0 {
                remove(&mut trees[k], old_val);
                insert(&mut trees[k], new_val);
                k >>= 1;
            }

            let mut k = pos + size;
            sums[k] = new_val;
            while k > 1 {
                k >>= 1;
                sums[k] = sums[2 * k] + sums[2 * k + 1];
            }
        } else {
            let raw_l: usize = tokens.next().unwrap().parse().unwrap();
            let mut l = (raw_l ^ (xor as usize)) & INDEX_MASK;
            let raw_r: usize = tokens.next().unwrap().parse().unwrap();
            let mut r = (raw_r ^ (xor as usize)) & INDEX_MASK;
            if l > r {
                std::mem::swap(&mut l, &mut r);
            }
            // 1-based inclusive [l, r] -> 0-based half-open [l, r).
            l -= 1;

            // (count, sum) of the range's values that are >= threshold.
            let count_at_least = |threshold: u64| -> (usize, u64) {
                let mut acc = (0usize, 0u64);
                for_each_cover(l, r, size, |k| {
                    let (c, s) = eval(&trees[k], threshold);
                    acc = (acc.0 + c, acc.1 + s);
                });
                acc
            };

            let len = r - l;
            let half = (len + 1) / 2;
            // Largest `ok` with at least `half` values >= ok, i.e. the
            // half-th largest value. `ng` is an exclusive upper bound.
            let mut ok = 0u64;
            let mut ng = 1u64 << 40;
            while ng - ok > 1 {
                let m = (ok + ng) / 2;
                if count_at_least(m).0 >= half {
                    ok = m;
                } else {
                    ng = m;
                }
            }
            let (cnt, at_least_sum) = count_at_least(ok);

            let mut total = 0u64;
            for_each_cover(l, r, size, |k| total += sums[k]);

            let len = len as u64;
            let half = half as u64;
            let cnt = cnt as u64;
            // Fewer than `half` values exceed `ok`, so the `cnt - half` extra
            // members of the ">= ok" group all equal `ok`; they belong to the
            // lower half.
            let surplus = ok * (cnt - half);
            // Sum of (v - ok) over the upper half.
            let above = at_least_sum - half * ok - surplus;
            // Sum of (ok - v) over the lower half.
            let below = ok * (len - half) - (total - at_least_sum + surplus);
            let ans = above + below;
            writeln!(out, "{}", ans).ok();
            xor ^= ans;
        }
    }
}

/// Program entry point; all of the work happens in `run`.
fn main() {
    run()
}
0