Segment Tree

Summary

In this post, I will introduce several templates for the segment tree. Compared to the binary indexed tree, the segment tree is easier to understand, though the code is longer. I will not cover the 2D segment tree in this post.

Some OJ problems are listed below.

  1. LC 1157: query range majority; update single node;
  2. LC 307: query range sum; update single node;

Details

LC 307

// Segment tree over an int array supporting a point update (add a delta to
// one element) and a range-sum query. Nodes are stored 1-indexed in `info`:
// node r has children 2r and 2r + 1 and covers a contiguous range
// [left, right] of array indices. Every internal routine visits at most one
// root-to-leaf path (update) or two boundary paths (query).
class SegmentTree {
private:
    std::vector<int> nums;  // copy of the input; only read while building
    std::vector<int> info;  // info[r] = sum of the elements covered by node r
    int n;                  // number of elements in the array

    // Fills info[] for the subtree rooted at `root`, which covers [left, right].
    void build(int left, int right, int root) {
        if (left > right) {
            return;  // empty array: nothing to build
        }
        if (left == right) {
            info[root] = nums[left];
            return;
        }
        const int mid = left + (right - left) / 2;
        build(left, mid, root << 1);
        build(mid + 1, right, (root << 1) + 1);
        // A parent stores the SUM of its halves (the original used max, which
        // contradicts the range-sum contract and the 0 returned by query()).
        info[root] = info[root << 1] + info[(root << 1) + 1];
    }

    // Adds `val` to element `q_index` inside the subtree rooted at `root`.
    // An index outside [left, right] is ignored.
    void update(int left, int right, int root, int q_index, int val) {
        if (q_index < left || q_index > right) {
            return;
        }
        // Every node on the path to the leaf covers q_index, so each of their
        // sums grows by exactly `val`.
        info[root] += val;
        if (left == right) {
            return;
        }
        const int mid = left + (right - left) / 2;
        if (q_index <= mid) {
            update(left, mid, root << 1, q_index, val);
        } else {
            update(mid + 1, right, (root << 1) + 1, q_index, val);
        }
    }

    // Returns the sum of the elements in [q_left, q_right] that lie inside
    // the range [left, right] covered by `root`.
    int query(int left, int right, int root, int q_left, int q_right) const {
        if (q_left > right || q_right < left) {
            return 0;  // disjoint: 0 is the identity for addition
        }
        if (q_left <= left && right <= q_right) {
            return info[root];  // node fully inside the query range
        }
        const int mid = left + (right - left) / 2;
        return query(left, mid, root << 1, q_left, q_right) +
               query(mid + 1, right, (root << 1) + 1, q_left, q_right);
    }

public:
    // Builds the tree from `nums_` in O(n). 4 * n slots are always enough for
    // the 1-indexed implicit binary tree.
    SegmentTree(std::vector<int> const &nums_)
        : nums(nums_), n(static_cast<int>(nums_.size())) {
        info.resize(4 * nums.size());
        build(0, n - 1, 1);
    }

    // Adds `val` to the element at `q_index` (a delta, not an assignment).
    // To set an element to x, pass x minus its current value. Out-of-range
    // indices are ignored.
    void Update(int q_index, int val) {
        if (n == 0) {
            return;
        }
        update(0, n - 1, 1, q_index, val);
    }

    // Returns the sum of the elements with index in [q_left, q_right]
    // (inclusive). Parts of the range outside the array contribute 0.
    int Query(int q_left, int q_right) {
        if (n == 0) {
            return 0;
        }
        return query(0, n - 1, 1, q_left, q_right);
    }
};

Query Range Sum & Update Range Nodes

// Segment tree with lazy propagation: add a value to every element of an
// index range, and query the sum of an index range. Nodes are 1-indexed;
// node k has children 2k and 2k + 1.
class SegmentTree {
private:
    std::vector<int> sums_;    // sums_[k] = sum of the range covered by node k
    std::vector<int> lazy_;    // lazy_[k] = per-element delta not yet pushed to k's children

    std::vector<int> source_;  // copy of the input, read while building
    int count_;                // number of elements

    // Records "add `delta` to every element of [lo, hi]" at `node` without
    // descending: the node's sum is adjusted and the delta is remembered.
    void apply(int node, int lo, int hi, int delta) {
        sums_[node] += (hi - lo + 1) * delta;
        lazy_[node] += delta;
    }

    // Builds the subtree rooted at `node`, which covers [lo, hi].
    void build(int node, int lo, int hi) {
        if (lo > hi) {
            return;
        }
        if (lo == hi) {
            sums_[node] += source_[lo];
            return;
        }
        const int mid = lo + (hi - lo) / 2;
        build(2 * node, lo, mid);
        build(2 * node + 1, mid + 1, hi);
        pull_up(node);
    }

    // Hands the delta pending at `node` over to both children and clears it.
    void push_down(int node, int lo, int hi) {
        const int mid = lo + (hi - lo) / 2;
        const int pending = lazy_[node];
        apply(2 * node, lo, mid, pending);
        apply(2 * node + 1, mid + 1, hi, pending);
        lazy_[node] = 0;
    }

    // Recomputes the sum of `node` from its two children.
    void pull_up(int node) {
        sums_[node] = sums_[2 * node] + sums_[2 * node + 1];
    }

    // Adds `val` to every element of [q_left, q_right] that lies in [lo, hi].
    void update(int node, int lo, int hi, int q_left, int q_right, int val) {
        const bool disjoint = hi < q_left || lo > q_right;
        if (disjoint) {
            return;
        }
        const bool covered = q_left <= lo && hi <= q_right;
        if (covered) {
            apply(node, lo, hi, val);
            return;
        }
        const int mid = lo + (hi - lo) / 2;
        push_down(node, lo, hi);  // children must be current before recursing
        update(2 * node, lo, mid, q_left, q_right, val);
        update(2 * node + 1, mid + 1, hi, q_left, q_right, val);
        pull_up(node);
    }

    // Sums the elements of [q_left, q_right] that lie in [lo, hi].
    int query(int node, int lo, int hi, int q_left, int q_right) {
        const bool disjoint = hi < q_left || lo > q_right;
        if (disjoint) {
            return 0;
        }
        const bool covered = q_left <= lo && hi <= q_right;
        if (covered) {
            return sums_[node];
        }
        const int mid = lo + (hi - lo) / 2;
        push_down(node, lo, hi);  // children must be current before reading them
        const int left_part = query(2 * node, lo, mid, q_left, q_right);
        const int right_part = query(2 * node + 1, mid + 1, hi, q_left, q_right);
        return left_part + right_part;
    }

public:
    // Copies `nums` and builds the tree over it.
    SegmentTree(std::vector<int> const& nums) {
        source_ = nums;
        count_ = nums.size();
        sums_.resize(4 * count_);
        lazy_.resize(4 * count_);
        build(1, 0, count_ - 1);
    }

    // Adds `val` to every element with index in [q_left, q_right].
    void Update(int q_left, int q_right, int val) {
        update(1, 0, count_ - 1, q_left, q_right, val);
    }

    // Returns the sum of the elements with index in [q_left, q_right].
    int Query(int q_left, int q_right) {
        return query(1, 0, count_ - 1, q_left, q_right);
    }
};

LC 1157

/*
Pre-calculation: O(n * log(n))
Query: O(log(n) * log(n))
Single Node Update: O(log(n) * log(n)) (not implemented in the class below)
*/

// Index of the left child of `root` in a 1-indexed implicit binary tree.
inline int left_child(int root) {
    return 2 * root;
}
// Index of the right child of `root` in a 1-indexed implicit binary tree.
inline int right_child(int root) {
    return 2 * root + 1;
}

// Answers range-majority queries: for a query (left, right, threshold) it
// returns an element that is a strict majority of arr[left..right] and occurs
// at least `threshold` times there, or -1 if there is none.
//
// A segment tree stores, per node, the strict-majority element of the node's
// range (if any). A strict majority of a range must also be a strict majority
// of at least one of the two halves it is split into, so candidates from the
// children are the only values that need to be verified. Verification counts
// occurrences with binary search on the sorted position list of the value.
//
// The value -1 is used as the "no candidate" sentinel, so array values are
// assumed to differ from -1.
class MajorityChecker {
private:
    // Candidate of a range: `num` with `freq` occurrences, or {-1, -1} for none.
    struct Info {
        int freq = -1;
        int num = -1;
        Info() {}
        Info(int freq_, int num_) : freq(freq_), num(num_) {}
        std::string ToString() const {return std::to_string(num) + "," + std::to_string(freq);}
    };
    int n;                                                    // array length
    std::vector<Info> infos;                                  // 1-indexed tree nodes
    std::vector<int> nums;                                    // copy of the array
    std::unordered_map<int, std::vector<int>> num_to_posi;    // value -> sorted positions

    // Number of occurrences of `num` at positions in [left, right].
    // Uses find() rather than operator[] so a lookup never inserts a key.
    int get_freq(int num, int left, int right) const {
        const auto entry = num_to_posi.find(num);
        if (entry == num_to_posi.end()) {
            return 0;
        }
        const std::vector<int>& positions = entry->second;
        const auto first = std::lower_bound(positions.begin(), positions.end(), left);
        const auto last = std::upper_bound(first, positions.end(), right);
        return static_cast<int>(std::distance(first, last));
    }

    // Returns {count, num} when `num` is a strict majority of [left, right],
    // otherwise the "none" Info. The count is computed once and reused.
    Info majority_or_none(int num, int left, int right) const {
        if (num == -1) {
            return {};
        }
        const int freq = get_freq(num, left, right);
        if (freq * 2 > right - left + 1) {
            return {freq, num};
        }
        return {};
    }

    // Computes infos[] for the subtree rooted at `root`, covering [left, right].
    void build(int root, int left, int right) {
        if (left > right) {
            return;
        }
        if (left == right) {
            infos[root] = {1, nums[left]};
            return;
        }
        const int mid = left + (right - left) / 2;
        build(left_child(root), left, mid);
        build(right_child(root), mid + 1, right);

        const Info& lhs = infos[left_child(root)];
        const Info& rhs = infos[right_child(root)];
        if (lhs.num != -1 && lhs.num == rhs.num) {
            // Same majority in both halves: counts simply add up.
            infos[root] = {lhs.freq + rhs.freq, lhs.num};
            return;
        }
        // Otherwise try each child's candidate against the whole range.
        Info merged = majority_or_none(lhs.num, left, right);
        if (merged.num == -1) {
            merged = majority_or_none(rhs.num, left, right);
        }
        infos[root] = merged;
    }

    // Returns the candidate for the query range as seen from `root`
    // (covering [left, right]). For a node fully inside the query range this
    // is the node's own majority; for a partially covered node it is a child
    // candidate verified against the whole range [q_left, q_right].
    Info query_range(int root, int left, int right, int q_left, int q_right) const {
        if (left > q_right || right < q_left) {
            return {};
        }
        if (q_left <= left && right <= q_right) {
            return infos[root];
        }
        const int mid = left + (right - left) / 2;
        const Info from_left = query_range(left_child(root), left, mid, q_left, q_right);
        const Info verified_left = majority_or_none(from_left.num, q_left, q_right);
        if (verified_left.num != -1) {
            return verified_left;
        }
        const Info from_right = query_range(right_child(root), mid + 1, right, q_left, q_right);
        return majority_or_none(from_right.num, q_left, q_right);
    }

public:
    // Debug helper: prints every tree slot as "num,freq;" on one line.
    void PrintTree() const {
        for (int i = 1; i < 4 * n; ++i) {
            std::cout << infos[i].ToString() << ";";
        }
        std::cout << std::endl;
    }

    // Indexes the positions of every value and builds the tree over `arr`.
    MajorityChecker(std::vector<int>& arr) {
        n = static_cast<int>(arr.size());
        infos.resize(arr.size() * 4);
        nums = arr;
        for (int i = 0; i < n; ++i) {
            num_to_posi[arr[i]].push_back(i);  // i increases, so lists stay sorted
        }
        build(1, 0, n - 1);
    }

    // Returns the strict-majority element of arr[left..right] if it occurs at
    // least `threshold` times there, otherwise -1.
    int query(int left, int right, int threshold) {
        if (n == 0) {
            return -1;
        }
        const Info info = query_range(1, 0, n - 1, left, right);
        if (info.freq >= threshold) {
            return info.num;
        }
        return -1;
    }
};

Leave a Reply

Your email address will not be published. Required fields are marked *