Summary
In this post, I introduce several segment tree templates. Compared with the binary indexed tree (Fenwick tree), the segment tree is easier to understand, though its code is longer. 2D segment trees are not covered in this post.
Related online judge (OJ) problems are listed below; "LC" stands for LeetCode.
Details
LC 307 (template: range-maximum query with single-element update)
// Segment tree answering range-maximum queries with single-element additive updates.
// Indices are 0-based; the tree is stored 1-indexed, so the children of node `root`
// are `root << 1` and `(root << 1) + 1`.
class SegmentTree {
private:
    // Copy of the input values; read by build() only, not kept in sync by Update().
    vector<int> nums;
    // info[root] = maximum of the segment covered by node `root`.
    vector<int> info;
    int n;

    // Fills info[] for the segment nums[left..right] rooted at node `root`.
    void build(int left, int right, int root) {
        if (left > right) {
            return;  // empty input: nothing to build
        }
        if (left == right) {
            info[root] = nums[left];
            return;
        }
        int mid = left + (right - left) / 2;
        build(left, mid, root << 1);
        build(mid + 1, right, (root << 1) + 1);
        info[root] = max(info[root << 1], info[(root << 1) + 1]);
    }

    // Adds `val` to element q_index (if it lies in [left, right]) and refreshes the
    // maxima on the path back up to `root`.
    void update(int left, int right, int root, int q_index, int val) {
        if (q_index < left || q_index > right) {
            return;
        }
        if (left == right) {
            info[root] += val;
            return;
        }
        int mid = left + (right - left) / 2;
        // Only the half that contains q_index can change.
        if (q_index <= mid) {
            update(left, mid, root << 1, q_index, val);
        } else {
            update(mid + 1, right, (root << 1) + 1, q_index, val);
        }
        info[root] = max(info[root << 1], info[(root << 1) + 1]);
    }

    // Maximum of nums over [q_left, q_right] intersected with [left, right].
    // The recursion only descends into children that overlap the query, so a disjoint
    // child never feeds a placeholder into max(). Returning 0 for disjoint children
    // (as a "neutral" value) gives wrong answers when all covered values are negative.
    int query(int left, int right, int root, int q_left, int q_right) {
        if (q_left > right || q_right < left) {
            // Only reached when the requested range is empty (q_left > q_right) or lies
            // entirely outside [0, n - 1] (including n == 0); such queries return 0.
            return 0;
        }
        if (q_left <= left && right <= q_right) {
            return info[root];
        }
        int mid = left + (right - left) / 2;
        if (q_right <= mid) {
            return query(left, mid, root << 1, q_left, q_right);
        }
        if (q_left > mid) {
            return query(mid + 1, right, (root << 1) + 1, q_left, q_right);
        }
        return max(query(left, mid, root << 1, q_left, q_right),
                   query(mid + 1, right, (root << 1) + 1, q_left, q_right));
    }

public:
    // Builds the tree over a copy of nums_ (4 * size slots is enough for this layout).
    SegmentTree(vector<int> const &nums_)
        : nums(nums_), info(4 * nums_.size()), n(static_cast<int>(nums_.size())) {
        build(0, n - 1, 1);
    }

    // Adds `val` to the element at q_index (an increment, not an assignment).
    void Update(int q_index, int val) {
        update(0, n - 1, 1, q_index, val);
    }

    // Returns the maximum of the elements with indices in [q_left, q_right] (inclusive).
    // Returns 0 if that range is empty or does not overlap the array.
    int Query(int q_left, int q_right) {
        return query(0, n - 1, 1, q_left, q_right);
    }
};
Query Range Sum & Update Range Nodes
// Segment tree with lazy propagation: adds a value to every element of a range and
// answers range-sum queries. Indices are 0-based; the tree is stored 1-indexed, so the
// children of node `root` are `root << 1` and `(root << 1) + 1`.
class SegmentTree {
private:
    // Node sums and pending additions are stored as long long so that internal nodes
    // (e.g. the root) cannot overflow int even when every queried range fits in int.
    vector<long long> info;  // info[root] = sum of the segment covered by `root`
    vector<long long> tag;   // tag[root] = addition applied to root but not yet to its children
    vector<int> nums;        // copy of the input, read by build() only
    int n;

    // Fills info[] for the segment nums[left..right] rooted at node `root`.
    void build(int root, int left, int right) {
        if (left > right) {
            return;  // empty input: nothing to build
        }
        if (left == right) {
            info[root] = nums[left];
            return;
        }
        int mid = left + (right - left) / 2;
        build(root << 1, left, mid);
        build((root << 1) + 1, mid + 1, right);
        pull_up(root);
    }

    // Moves the pending addition of `root` (covering [left, right]) onto its two children.
    void push_down(int root, int left, int right) {
        if (tag[root] == 0) {
            return;  // nothing pending
        }
        int mid = left + (right - left) / 2;
        info[root << 1] += (mid - left + 1) * tag[root];
        tag[root << 1] += tag[root];
        info[(root << 1) + 1] += (right - mid) * tag[root];
        tag[(root << 1) + 1] += tag[root];
        tag[root] = 0;
    }

    // Recomputes a node's sum from its children.
    void pull_up(int root) {
        info[root] = info[root << 1] + info[(root << 1) + 1];
    }

    // Adds `val` to every element in [q_left, q_right] intersected with [left, right].
    void update(int root, int left, int right, int q_left, int q_right, int val) {
        if (right < q_left || left > q_right) {
            return;
        }
        if (q_left <= left && right <= q_right) {
            // Fully covered: update this node and defer the children via the tag.
            info[root] += static_cast<long long>(right - left + 1) * val;
            tag[root] += val;
            return;
        }
        int mid = left + (right - left) / 2;
        push_down(root, left, right);
        update(root << 1, left, mid, q_left, q_right, val);
        update((root << 1) + 1, mid + 1, right, q_left, q_right, val);
        pull_up(root);
    }

    // Sum of the elements in [q_left, q_right] intersected with [left, right].
    long long query(int root, int left, int right, int q_left, int q_right) {
        if (right < q_left || left > q_right) {
            return 0;
        }
        if (q_left <= left && right <= q_right) {
            return info[root];
        }
        int mid = left + (right - left) / 2;
        push_down(root, left, right);
        return query(root << 1, left, mid, q_left, q_right)
             + query((root << 1) + 1, mid + 1, right, q_left, q_right);
    }

public:
    // Builds the tree over a copy of nums (4 * size slots is enough for this layout).
    SegmentTree(vector<int> const& nums)
        : info(4 * nums.size()), tag(4 * nums.size()), nums(nums),
          n(static_cast<int>(nums.size())) {
        build(1, 0, n - 1);
    }

    // Adds `val` to every element with index in [q_left, q_right] (inclusive).
    void Update(int q_left, int q_right, int val) {
        update(1, 0, n - 1, q_left, q_right, val);
    }

    // Returns the sum of the elements with indices in [q_left, q_right] (inclusive).
    // The sum is computed in long long and narrowed to int on return, so a result that
    // does not fit in int cannot be represented by this interface.
    int Query(int q_left, int q_right) {
        return static_cast<int>(query(1, 0, n - 1, q_left, q_right));
    }
};
LC 1157
/*
Complexity of MajorityChecker below (n = arr.size()):
Pre-calculation (build): O(n * log(n))
Query: O(log(n) * log(n))
Single Node Update: O(log(n) * log(n)) -- not implemented in this class
*/
// Index of the left child of node `root` in a 1-indexed, array-backed binary tree.
inline int left_child(int root) {
    return 2 * root;
}
// Index of the right child of node `root`: the slot right after the left child.
inline int right_child(int root) {
    return left_child(root) + 1;
}
// Answers "which value occurs at least `threshold` times in arr[left..right]?" using a
// segment tree of majority candidates plus per-value sorted position lists.
class MajorityChecker {
private:
    // Majority candidate of a segment: `num` occurs `freq` times in that segment, more
    // than half of its length. {-1, -1} means "no majority" (so -1 doubles as a sentinel).
    struct Info {
        int freq = -1;
        int num = -1;
        Info() {}
        Info(int freq_, int num_) : freq(freq_), num(num_) {}
        string ToString() const {return to_string(num) + "," + to_string(freq);}
    };
    int n;
    vector<Info> infos;  // 1-indexed tree: children via left_child/right_child
    vector<int> nums;
    unordered_map<int, vector<int>> num_to_posi;  // value -> increasing list of its indices

    // Number of occurrences of `num` in nums[left..right] (inclusive), by binary search.
    int get_freq(int num, int left, int right) const {
        auto found = num_to_posi.find(num);
        if (found == num_to_posi.end()) {
            return 0;
        }
        const vector<int>& posi = found->second;
        auto it1 = lower_bound(posi.begin(), posi.end(), left);
        auto it2 = upper_bound(posi.begin(), posi.end(), right);
        return static_cast<int>(distance(it1, it2));
    }

    // Returns whichever of the two candidates is a strict majority of nums[lo..hi]
    // (checking `first` first), with its exact count there; {-1, -1} if neither is.
    // A majority of a union of segments must be a majority of one of the parts, so
    // checking only the two child candidates is sufficient.
    Info pick_majority(const Info& first, const Info& second, int lo, int hi) const {
        const int len = hi - lo + 1;
        if (first.num != -1) {
            const int freq = get_freq(first.num, lo, hi);
            if (freq * 2 > len) {
                return {freq, first.num};
            }
        }
        if (second.num != -1) {
            const int freq = get_freq(second.num, lo, hi);
            if (freq * 2 > len) {
                return {freq, second.num};
            }
        }
        return {-1, -1};
    }

    // Fills infos[] for the segment nums[left..right] rooted at node `root`.
    void build(int root, int left, int right) {
        if (left > right) {
            return;  // empty input: nothing to build
        }
        if (left == right) {
            infos[root] = {1, nums[left]};
            return;
        }
        int mid = left + (right - left) / 2;
        build(left_child(root), left, mid);
        build(right_child(root), mid + 1, right);
        infos[root] = pick_majority(infos[left_child(root)], infos[right_child(root)], left, right);
    }

    // Majority candidate of [q_left, q_right] as seen from node `root` covering [left, right].
    Info query(int root, int left, int right, int q_left, int q_right) const {
        if (left > q_right || right < q_left) {
            return {-1, -1};
        }
        if (q_left <= left && right <= q_right) {
            return infos[root];
        }
        int mid = left + (right - left) / 2;
        Info left_info = query(left_child(root), left, mid, q_left, q_right);
        Info right_info = query(right_child(root), mid + 1, right, q_left, q_right);
        // Candidates are validated against the whole query range, so any candidate that
        // survives is the overall majority, with its count over [q_left, q_right].
        return pick_majority(left_info, right_info, q_left, q_right);
    }

public:
    // Debug helper: prints every tree slot as "num,freq;".
    void PrintTree() {
        for (int i = 1; i < 4 * n; ++i) {
            cout << infos[i].ToString() << ";";
        }
        cout << endl;
    }

    MajorityChecker(const vector<int>& arr)
        : n(static_cast<int>(arr.size())), infos(arr.size() << 2), nums(arr) {
        // Indices are appended in increasing order, so each position list is sorted,
        // which get_freq's binary searches rely on.
        for (int i = 0; i < n; ++i) {
            num_to_posi[arr[i]].push_back(i);
        }
        build(1, 0, n - 1);
    }

    // Returns the value occurring at least `threshold` times in arr[left..right], or -1.
    // Only a strict majority of the range is ever reported as a candidate.
    int query(int left, int right, int threshold) const {
        Info info = query(1, 0, n - 1, left, right);
        if (info.freq >= threshold) {
            return info.num;
        }
        return -1;
    }
};