結果

問題 No.3078 Difference Sum Query
ユーザー Nauclhlt🪷
提出日時 2025-03-28 22:10:27
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 1,452 ms / 2,000 ms
コード長 21,312 bytes
コンパイル時間 2,916 ms
コンパイル使用メモリ 226,848 KB
実行使用メモリ 134,652 KB
最終ジャッジ日時 2025-03-28 22:10:58
合計ジャッジ時間 26,311 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 26
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;

using ll = long long;


#include <vector>
#include <functional>
#include <stdexcept>

#include <vector>
#include <functional>
#include <stdexcept>

/**
 * Persistent segment tree over `n` leaves combined with the associative `op`.
 *
 * Every point update path-copies O(log n) nodes and registers a new root
 * ("snapshot"). Older snapshots stay valid and are addressed by the integer id
 * returned from build() / buildClear() / apply().
 *
 * Nodes are stored in a pool (std::deque keeps element addresses stable while
 * it grows) that the tree owns, so node memory is released when the last copy
 * of the tree is destroyed instead of leaking. Copies of the tree share the
 * pool, just as shallow copies previously shared the node graph.
 */
template <typename T>
class PersistentSegmentTree {
private:
    struct Node {
        T data;
        Node* left;
        Node* right;

        explicit Node(T data) : data(data), left(nullptr), right(nullptr) {}
    };

    int size;
    std::function<T(T, T)> op;
    std::function<T(T, T)> _apply;
    T identity;
    std::vector<Node*> snapshots;
    std::shared_ptr<std::deque<Node>> pool;

    // Creates a childless node holding `value` inside the pool.
    Node* newNode(T value) {
        pool->emplace_back(value);
        return &pool->back();
    }

    // Builds a subtree over [l, r) whose leaves all hold `value`.
    Node* buildClearRange(int l, int r, T value) {
        if (l + 1 >= r) return newNode(value);
        const int m = (l + r) / 2;
        return mergeNode(buildClearRange(l, m, value), buildClearRange(m, r, value));
    }

    // Builds a subtree over [l, r) from array[l..r). An empty range yields a
    // single identity leaf (the original read array[0] here, which is UB for an
    // empty array).
    Node* buildRange(int l, int r, const std::vector<T>& array) {
        if (l + 1 >= r) return newNode(l < r ? array[l] : identity);
        const int m = (l + r) / 2;
        return mergeNode(buildRange(l, m, array), buildRange(m, r, array));
    }

    // Creates an inner node whose value is op(l, r).
    Node* mergeNode(Node* l, Node* r) {
        Node* res = newNode(op(l->data, r->data));
        res->left = l;
        res->right = r;
        return res;
    }

    int registerNode(Node* node) {
        snapshots.push_back(node);
        return static_cast<int>(snapshots.size()) - 1;
    }

    Node* getRootAt(int time) const {
        return snapshots[time];
    }

    // Returns a new subtree equal to `node` except that the leaf at `index`
    // becomes _apply(old, value). Untouched children are shared, not copied.
    Node* applyRec(int index, T value, Node* node, int l, int r) {
        if (r <= index || index + 1 <= l) {
            return node;
        } else if (index <= l && r <= index + 1) {
            return newNode(_apply(node->data, value));
        } else {
            const int m = (l + r) / 2;
            return mergeNode(applyRec(index, value, node->left, l, m), applyRec(index, value, node->right, m, r));
        }
    }

    // Folds op over leaves in [left, right) ∩ [l, r).
    T queryRec(int left, int right, Node* node, int l, int r) const {
        if (r <= left || right <= l) {
            return identity;
        } else if (left <= l && r <= right) {
            return node->data;
        } else {
            const int m = (l + r) / 2;
            return op(queryRec(left, right, node->left, l, m), queryRec(left, right, node->right, m, r));
        }
    }

public:
    /// @param n        number of leaves
    /// @param op       associative combine used for inner nodes and queries
    /// @param _apply   how apply() merges a new value into an existing leaf
    /// @param identity neutral element of `op`, returned for empty query ranges
    PersistentSegmentTree(int n, std::function<T(T, T)> op, std::function<T(T, T)> _apply, T identity)
        : size(n), op(std::move(op)), _apply(std::move(_apply)), identity(identity),
          pool(std::make_shared<std::deque<Node>>()) {}

    /// Builds a snapshot from `array` (which must have exactly n elements).
    /// @throws std::invalid_argument on a size mismatch.
    /// @return the new snapshot id.
    int build(const std::vector<T>& array) {
        if (array.size() != static_cast<std::size_t>(size)) {
            throw std::invalid_argument("Size of the specified array does not match with the data size passed in the constructor.");
        }
        return registerNode(buildRange(0, size, array));
    }

    /// Builds a snapshot where every leaf holds `value`; returns its id.
    int buildClear(T value) {
        return registerNode(buildClearRange(0, size, value));
    }

    /// Forgets all snapshot ids. Node memory is kept until the tree (and every
    /// copy sharing its pool) is destroyed, because copies may still use it.
    void clearSnapshots() {
        snapshots.clear();
    }

    /// Creates a new snapshot from `time` with leaf `index` replaced by
    /// _apply(leaf, value); returns the new id. `time` is left unchanged.
    int apply(int time, int index, T value) {
        return registerNode(applyRec(index, value, getRootAt(time), 0, size));
    }

    /// Folds op over leaves [left, right) of snapshot `time`.
    T query(int time, int left, int right) const {
        return queryRec(left, right, getRootAt(time), 0, size);
    }

    /// Returns leaf `index` of snapshot `time`.
    /// @throws std::out_of_range if index is not in [0, n).
    T getByIndex(int time, int index) const {
        if (index < 0 || index >= size) {
            throw std::out_of_range("The specified index is out of range.");
        }
        Node* current = getRootAt(time);
        int l = 0, r = size;
        while (r - l > 1) {
            const int m = (l + r) / 2;
            if (index < m) {
                r = m;
                current = current->left;
            } else {
                l = m;
                current = current->right;
            }
        }
        return current->data;
    }
};


/// AVL-balanced binary search tree that stores subtree sizes, usable as an
/// ordered multiset: duplicates are allowed, and indices are 0-based in-order
/// ranks. Nodes are owned through shared_ptr.
template <typename T>
class AVLTree
{
private:
    /// Tree node with cached height, size and balance factor (left height
    /// minus right height). The caches are refreshed by AVLTree::Update.
    class Node
    {
    private:
        T _value;
        shared_ptr<Node> _left;
        shared_ptr<Node> _right;
        int _bias;
        int _height;
        int _size;

    public:
        Node(T value)
        {
            _value = value;
            _left = shared_ptr<Node>();
            _right = shared_ptr<Node>();
            _bias = 0;
            _height = 1;
            _size = 1;
        }

        inline bool Has2Children() { return _left && _right; }
        inline bool HasOnlyLeft() { return _left && !_right; }
        inline bool HasOnlyRight() { return !_left && _right; }
        inline bool HasNoChild() { return !_left && !_right; }
        inline bool HasRight() { return (bool)_right; }
        inline bool HasLeft() { return (bool)_left; }

        inline void SetValue(T value) { _value = value; }
        inline T GetValue() { return _value; }

        inline void SetLeft(shared_ptr<Node> left) { _left = left; }
        inline shared_ptr<Node> GetLeft() { return _left; }

        inline void SetRight(shared_ptr<Node> right) { _right = right; }
        inline shared_ptr<Node> GetRight() { return _right; }

        inline void SetBias(int bias) { _bias = bias; }
        inline int GetBias() { return _bias; }

        inline void SetHeight(int height) { _height = height; }
        inline int GetHeight() { return _height; }

        inline void SetSize(int size) { _size = size; }
        inline int GetSize() { return _size; }

        inline int LeftHeight() { return _left ? _left->GetHeight() : 0; }
        inline int RightHeight() { return _right ? _right->GetHeight() : 0; }
        inline int LeftSize() { return _left ? _left->GetSize() : 0; }
        inline int RightSize() { return _right ? _right->GetSize() : 0; }
    };

    shared_ptr<Node> _rootNode;


public:
    AVLTree()
    {
        _rootNode = shared_ptr<Node>();
    }

    /// Number of stored elements (duplicates counted).
    int Count()
    {
        return SizeOf(_rootNode);
    }

    /// Inserts `value`; values equal to an existing node go to its right.
    void Add(T value)
    {
        _rootNode = AddRecursive(_rootNode, value);
    }

    /// Removes one occurrence of `value`, if present.
    void Remove(T value)
    {
        _rootNode = RemoveRecursive(_rootNode, value);
    }

    /// Debug helper: prints the tree sideways (right subtree on top).
    void _PrintTree()
    {
        auto printNode = [&](shared_ptr<Node> node, int depth, auto self)
        {
            if (!node)
                return;
            self(node->GetRight(), depth + 1, self);
            for (int i = 0; i < depth; i++)
            {
                cout << "\t";
            }
            cout << node->GetValue() << endl;
            self(node->GetLeft(), depth + 1, self);
        };

        printNode(_rootNode, 0, printNode);
    }

private:
    // Removes one occurrence of `value` from the subtree and returns its new root.
    shared_ptr<Node> RemoveRecursive(shared_ptr<Node> current, T value)
    {
        if (!current)
        {
            return shared_ptr<Node>();
        }

        if (current->GetValue() == value)
        {
            return InternalRemoveNode(current);
        }

        if (value < current->GetValue())
        {
            current->SetLeft(RemoveRecursive(current->GetLeft(), value));
        }
        else
        {
            current->SetRight(RemoveRecursive(current->GetRight(), value));
        }

        Update(current);
        return Balance(current);
    }

    // Unlinks `target` and returns the root of the subtree that replaces it.
    shared_ptr<Node> InternalRemoveNode(shared_ptr<Node> target)
    {
        if (target->Has2Children())
        {
            // Overwrite target with its in-order predecessor, which is the maximum
            // of the LEFT subtree, and unlink that predecessor node. (Taking the
            // maximum of the whole subtree would pick a node on the right and
            // break the ordering.)
            shared_ptr<Node> pred = GetMaxNode(target->GetLeft());
            T val = pred->GetValue();

            if (target->GetLeft() == pred)
            {
                target->SetLeft(pred->GetLeft());
            }
            else
            {
                target->SetLeft(DeleteRightNode(target->GetLeft(), pred));
            }

            target->SetValue(val);

            Update(target->GetLeft());
            Update(target);
            return Balance(target);
        }
        else if (target->HasOnlyLeft() || target->HasOnlyRight())
        {
            // A single child simply takes target's place.
            shared_ptr<Node> child = target->HasLeft() ? target->GetLeft() : target->GetRight();

            Update(child->GetLeft());
            Update(child->GetRight());
            Update(child);
            return Balance(child);
        }
        else
        {
            return shared_ptr<Node>();
        }
    }

    // Inserts `value` into the subtree and returns its new (rebalanced) root.
    shared_ptr<Node> AddRecursive(shared_ptr<Node> current, T value)
    {
        if (!current)
        {
            current = shared_ptr<Node>(new Node(value));
            Update(current);
            return current;
        }

        if (value < current->GetValue())
        {
            current->SetLeft(AddRecursive(current->GetLeft(), value));
        }
        else
        {
            current->SetRight(AddRecursive(current->GetRight(), value));
        }

        Update(current);
        return Balance(current);
    }

    // Returns the value at in-order rank `offset` within the subtree.
    T GetByIndexRecursive(shared_ptr<Node> current, int offset)
    {
        int left = current->LeftSize();
        if (left == offset)
        {
            return current->GetValue();
        }
        if (offset < left)
        {
            return GetByIndexRecursive(current->GetLeft(), offset);
        }
        return GetByIndexRecursive(current->GetRight(), offset - left - 1);
    }

    // Restores the AVL property at `node` (whose cached fields must be current).
    // The rotations recompute heights and biases through Update.
    shared_ptr<Node> Balance(shared_ptr<Node> node)
    {
        int bias = node->GetBias();

        if (bias >= 2)
        {
            // Left-heavy. A left child with bias >= 0 (0 can happen after a
            // deletion) needs a single right rotation. A right-leaning left child
            // needs a left-right double rotation.
            if (node->GetLeft()->GetBias() < 0)
            {
                node->SetLeft(RotateLeft(node->GetLeft()));
            }
            return RotateRight(node);
        }

        if (bias <= -2)
        {
            // Mirror image of the case above.
            if (node->GetRight()->GetBias() > 0)
            {
                node->SetRight(RotateRight(node->GetRight()));
            }
            return RotateLeft(node);
        }

        return node;
    }

    // Removes `target` from the right spine of `root` and returns the new root.
    shared_ptr<Node> DeleteRightNode(shared_ptr<Node> root, shared_ptr<Node> target)
    {
        if (!root)
            return shared_ptr<Node>();

        if (root->GetRight() == target)
        {
            root->SetRight(root->GetRight()->GetLeft());
            Update(root->GetRight());
        }
        else
        {
            root->SetRight(DeleteRightNode(root->GetRight(), target));
        }

        Update(root);
        return Balance(root);
    }

    shared_ptr<Node> GetMaxNode(shared_ptr<Node> node)
    {
        shared_ptr<Node> cur = node;
        while (cur->HasRight())
        {
            cur = cur->GetRight();
        }
        return cur;
    }

    shared_ptr<Node> GetMinNode(shared_ptr<Node> node)
    {
        shared_ptr<Node> cur = node;
        while (cur->HasLeft())
        {
            cur = cur->GetLeft();
        }
        return cur;
    }

    shared_ptr<Node> RotateLeft(shared_ptr<Node> node)
    {
        shared_ptr<Node> right = node->GetRight();
        node->SetRight(right->GetLeft());
        right->SetLeft(node);

        Update(right->GetLeft());
        Update(right->GetRight());
        Update(right);

        return right;
    }

    shared_ptr<Node> RotateRight(shared_ptr<Node> node)
    {
        shared_ptr<Node> left = node->GetLeft();
        node->SetLeft(left->GetRight());
        left->SetRight(node);

        Update(left->GetLeft());
        Update(left->GetRight());
        Update(left);

        return left;
    }

    // Recomputes the cached height, size and bias of `node` from its children.
    void Update(shared_ptr<Node> node)
    {
        if (!node)
            return;

        node->SetHeight(HeightOf(node));
        node->SetSize(SizeOf(node));
        node->SetBias(node->LeftHeight() - node->RightHeight());
    }

    int HeightOf(shared_ptr<Node> node)
    {
        if (!node)
            return 0;

        return max(node->LeftHeight(), node->RightHeight()) + 1;
    }

    int SizeOf(shared_ptr<Node> node)
    {
        if (!node)
            return 0;

        return node->LeftSize() + node->RightSize() + 1;
    }

public:
    bool Contains(T value)
    {
        shared_ptr<Node> current = _rootNode;

        while (current)
        {
            if (current->GetValue() == value)
            {
                return true;
            }
            current = value < current->GetValue() ? current->GetLeft() : current->GetRight();
        }

        return false;
    }

    /// Largest element. @throws std::out_of_range if the tree is empty.
    T Max()
    {
        if (!_rootNode)
        {
            throw out_of_range("The tree is empty.");
        }
        return GetMaxNode(_rootNode)->GetValue();
    }

    /// Smallest element. @throws std::out_of_range if the tree is empty.
    T Min()
    {
        if (!_rootNode)
        {
            throw out_of_range("The tree is empty.");
        }
        return GetMinNode(_rootNode)->GetValue();
    }

    /// Element at 0-based in-order rank `index`.
    /// @throws std::out_of_range if index is not in [0, Count()).
    T GetByIndex(int index)
    {
        if (!_rootNode)
        {
            throw out_of_range("The specified index is out of range.");
        }

        if (index < 0 || index >= Count())
        {
            throw out_of_range("The specified index is out of range.");
        }

        return GetByIndexRecursive(_rootNode, index);
    }

    /// Rank of an occurrence of `value` (the first one met on the search path),
    /// or -1 if absent.
    int IndexOf(T value)
    {
        if (!_rootNode)
        {
            return -1;
        }

        int index = _rootNode->LeftSize();
        shared_ptr<Node> current = _rootNode;

        while (true)
        {
            if (value < current->GetValue())
            {
                if (!current->HasLeft())
                {
                    return -1;
                }
                current = current->GetLeft();
                index -= current->RightSize() + 1;
            }
            else if (value == current->GetValue())
            {
                return index;
            }
            else
            {
                if (!current->HasRight())
                {
                    return -1;
                }
                current = current->GetRight();
                index += current->LeftSize() + 1;
            }
        }
    }

    /// Number of elements strictly less than `value` (index of the first >= value).
    int LowerBound(T value)
    {
        if (!_rootNode)
        {
            return 0;
        }

        int res = _rootNode->GetSize();
        shared_ptr<Node> current = _rootNode;
        int index = _rootNode->LeftSize();

        while (true)
        {
            if (value <= current->GetValue())
            {
                res = min(res, index);
                if (!current->HasLeft())
                {
                    break;
                }
                index -= current->GetLeft()->RightSize() + 1;
                current = current->GetLeft();
            }
            else
            {
                if (!current->HasRight())
                {
                    break;
                }
                index += current->GetRight()->LeftSize() + 1;
                current = current->GetRight();
            }
        }

        return res;
    }

    /// Smallest element >= `value`, or `fallback` if there is none.
    T LowerBoundValue(T value, T fallback)
    {
        if (!_rootNode)
        {
            return fallback;
        }

        int res = _rootNode->GetSize();
        shared_ptr<Node> current = _rootNode;
        int index = _rootNode->LeftSize();
        T lowerbound = fallback;

        while (true)
        {
            if (value <= current->GetValue())
            {
                res = min(res, index);
                lowerbound = current->GetValue();
                if (!current->HasLeft())
                {
                    break;
                }
                index -= current->GetLeft()->RightSize() + 1;
                current = current->GetLeft();
            }
            else
            {
                if (!current->HasRight())
                {
                    break;
                }
                index += current->GetRight()->LeftSize() + 1;
                current = current->GetRight();
            }
        }

        return res < Count() ? lowerbound : fallback;
    }

    /// All elements in non-decreasing order.
    vector<T> OrderAscending()
    {
        vector<T> res;
        res.reserve(Count());

        auto extract = [&](shared_ptr<Node> node, auto&& self) -> void
        {
            if (!node)
                return;
            self(node->GetLeft(), self);
            res.push_back(node->GetValue());
            self(node->GetRight(), self);
        };
        // The original defined this traversal but never ran it, so it always
        // returned an empty vector.
        extract(_rootNode, extract);

        return res;
    }

    /// All elements in non-increasing order.
    vector<T> OrderDescending()
    {
        vector<T> res;
        res.reserve(Count());

        auto extract = [&](shared_ptr<Node> node, auto&& self) -> void
        {
            if (!node)
                return;
            self(node->GetRight(), self);
            res.push_back(node->GetValue());
            self(node->GetLeft(), self);
        };
        extract(_rootNode, extract);

        return res;
    }
};

/// Ordered multiset facade over AVLTree<T>: all operations forward to the tree.
template <typename T>
class NauclhltSet
{
private:
    // `mutable` because AVLTree's read-only queries are not const-qualified.
    // Without it, the const operator[] (which calls GetByIndex) fails to compile
    // as soon as it is instantiated.
    mutable AVLTree<T> _tree;

public:
    inline int Count() const
    {
        return _tree.Count();
    }

    inline T Max() const
    {
        return _tree.Max();
    }

    inline T Min() const
    {
        return _tree.Min();
    }

    inline void Add(T item)
    {
        _tree.Add(item);
    }

    inline void Remove(T item)
    {
        _tree.Remove(item);
    }

    inline bool Contains(T item) const
    {
        return _tree.Contains(item);
    }

    inline int IndexOf(T item) const
    {
        return _tree.IndexOf(item);
    }

    inline int LowerBound(T value) const
    {
        return _tree.LowerBound(value);
    }

    inline T LowerBoundValue(T value, T fallback) const
    {
        return _tree.LowerBoundValue(value, fallback);
    }

    inline T GetByIndex(int index) const
    {
        return _tree.GetByIndex(index);
    }

    inline vector<T> OrderAscending() const
    {
        return _tree.OrderAscending();
    }

    inline vector<T> OrderDescending() const
    {
        return _tree.OrderDescending();
    }

    /// Element at 0-based rank `index`. It is returned by value, so a single
    /// const overload serves both const and non-const objects.
    T operator[](size_t index) const
    {
        return GetByIndex(static_cast<int>(index));
    }

    inline void _DebugPrintTree()
    {
        _tree._PrintTree();
    }
};


int main()
{
    cin.tie(nullptr);
    ios::sync_with_stdio(false);

    int N, Q;
    cin >> N >> Q;
    vector<ll> A(N);
    for (int i = 0; i < N; i++)
    {
        cin >> A[i];
    }

    vector<ll> seq;
    set<ll> set;
    for (int i = 0; i < N; i++)
    {
        seq.push_back(A[i]);
        set.insert(A[i]);
    }
    seq.erase(unique(seq.begin(), seq.end()), seq.end());
    sort(seq.begin(), seq.end());

    int max = seq.size();

    PersistentSegmentTree<ll> seg(max, [&](ll x, ll y) { return x + y; }, [&](ll x, ll a){return x + a; }, 0L);
    PersistentSegmentTree<ll> countseg(max, [&](ll x, ll y) { return x + y; }, [&](ll x, ll a){return x + a; }, 0L);

    vector<int> prefix(N + 1);
    prefix[0] = seg.buildClear(0L);
    countseg.buildClear(0L);

    for (int i = 1; i <= N; i++)
    {
        int idx = lower_bound(seq.begin(), seq.end(), A[i - 1]) - seq.begin();
        prefix[i] = seg.apply(prefix[i - 1], idx, A[i - 1]);
        countseg.apply(prefix[i - 1], idx, 1);
    }

    for (int i = 0; i < Q; i++)
    {
        int l, r;
        ll x;
        cin >> l >> r >> x;
        l--;

        int bound = lower_bound(seq.begin(), seq.end(), x) - seq.begin();
        if (set.find(x) != set.end()) bound++;

        ll downsum = seg.query(prefix[r], 0, bound) - seg.query(prefix[l], 0, bound);
        ll downcount = countseg.query(prefix[r], 0, bound) - countseg.query(prefix[l], 0, bound);
        ll upsum = seg.query(prefix[r], bound, max) - seg.query(prefix[l], bound, max);
        ll upcount = countseg.query(prefix[r], bound, max) - countseg.query(prefix[l], bound, max);

        ll ans = x * downcount - downsum + upsum - upcount * x;

        cout << ans << "\n";
    }
}
0