Segment Tree
โครงสร้างข้อมูล Segment Tree สำหรับ Range Query
Segment Tree
Segment Tree เป็นโครงสร้างข้อมูลที่ใช้ตอบ range queries และ point/range updates ได้อย่างมีประสิทธิภาพ
เมื่อไหร่ควรใช้ Segment Tree?
- Range sum/min/max queries
- Range updates
- ต้องการทั้ง query และ update ใน $O(\log n)$
โครงสร้าง
Segment Tree เป็น binary tree ที่:
- แต่ละ node เก็บข้อมูลของ range
- Root เก็บข้อมูลของทั้ง array
- Leaf nodes คือแต่ละ element
                 [0-7]
               /       \
           [0-3]       [4-7]
          /    \       /    \
      [0-1]  [2-3]  [4-5]  [6-7]
      /  \    /  \   /  \   /  \
    [0] [1] [2] [3] [4] [5] [6] [7]
Basic Implementation
Range Sum Query + Point Update
/// Segment tree supporting point assignment and inclusive range-sum queries.
///
/// Nodes live in a flat, 1-indexed array: node `i` has children `2*i` and
/// `2*i + 1`, and the root (node 1) covers the index range [0, n-1].
/// An empty input array yields an empty tree: queries return 0 and updates
/// are ignored.
class SegmentTree {
private:
    std::vector<long long> tree;  // tree[node] = sum over the node's range
    int n;                        // number of elements in the source array

    /// Fills `node`, which covers [start, end], and its subtree from `arr`.
    /// Requires start <= end.
    void build(const std::vector<int>& arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        const int mid = start + (end - start) / 2;
        build(arr, 2 * node, start, mid);
        build(arr, 2 * node + 1, mid + 1, end);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    /// Sets position `idx` to `val` and refreshes the sums on the path back
    /// to `node`. Requires start <= idx <= end.
    void update(int node, int start, int end, int idx, int val) {
        if (start == end) {
            tree[node] = val;
            return;
        }
        const int mid = start + (end - start) / 2;
        if (idx <= mid) {
            update(2 * node, start, mid, idx, val);
        } else {
            update(2 * node + 1, mid + 1, end, idx, val);
        }
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    /// Returns the sum over the intersection of [l, r] with [start, end].
    long long query(int node, int start, int end, int l, int r) const {
        if (r < start || end < l) {
            return 0;  // no overlap contributes nothing to the sum
        }
        if (l <= start && end <= r) {
            return tree[node];  // node range is fully covered
        }
        const int mid = start + (end - start) / 2;
        return query(2 * node, start, mid, l, r) +
               query(2 * node + 1, mid + 1, end, l, r);
    }

public:
    /// Builds the tree from `arr`. Taking a const reference lets callers
    /// pass const vectors and temporaries.
    SegmentTree(const std::vector<int>& arr) {
        n = static_cast<int>(arr.size());
        if (n == 0) {
            return;  // build() needs at least one element to reach a leaf
        }
        tree.resize(4 * arr.size());
        build(arr, 1, 0, n - 1);
    }

    /// Assigns `val` to position `idx`. Indices outside [0, n-1] are
    /// ignored; descending with such an index would overwrite an unrelated
    /// leaf.
    void update(int idx, int val) {
        if (idx < 0 || idx >= n) {
            return;
        }
        update(1, 0, n - 1, idx, val);
    }

    /// Returns the sum of the elements whose index lies in [l, r]
    /// (inclusive). Parts of the range outside the array contribute 0.
    long long query(int l, int r) const {
        if (n == 0) {
            return 0;
        }
        return query(1, 0, n - 1, l, r);
    }
};
Range Minimum Query
/// Segment tree answering inclusive range-minimum queries over a fixed array.
///
/// Nodes live in a flat, 1-indexed array: node `i` has children `2*i` and
/// `2*i + 1`, and the root (node 1) covers the index range [0, n-1].
class SegmentTreeMin {
private:
    std::vector<int> tree;  // tree[node] = minimum over the node's range
    int n;                  // number of elements in the source array

    // Value returned when a query covers no element at all. Declared static
    // so the class stays copy-assignable; it is never merged into a real
    // answer, so stored values above it are still reported correctly.
    static constexpr int INF = 1000000000;

    /// Fills `node`, which covers [start, end], and its subtree from `arr`.
    /// Requires start <= end.
    void build(const std::vector<int>& arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        const int mid = start + (end - start) / 2;
        build(arr, 2 * node, start, mid);
        build(arr, 2 * node + 1, mid + 1, end);
        tree[node] = std::min(tree[2 * node], tree[2 * node + 1]);
    }

    /// Returns the minimum over [l, r] within the subtree of `node`.
    /// Requires l <= r and that [l, r] overlaps [start, end]; only children
    /// that overlap the query are visited, so no sentinel is needed.
    int query(int node, int start, int end, int l, int r) const {
        if (l <= start && end <= r) {
            return tree[node];
        }
        const int mid = start + (end - start) / 2;
        if (r <= mid) {
            return query(2 * node, start, mid, l, r);
        }
        if (l > mid) {
            return query(2 * node + 1, mid + 1, end, l, r);
        }
        return std::min(query(2 * node, start, mid, l, r),
                        query(2 * node + 1, mid + 1, end, l, r));
    }

public:
    /// Builds the tree from `arr`. An empty array yields an empty tree.
    SegmentTreeMin(const std::vector<int>& arr) {
        n = static_cast<int>(arr.size());
        if (n == 0) {
            return;  // build() needs at least one element to reach a leaf
        }
        tree.resize(4 * arr.size());
        build(arr, 1, 0, n - 1);
    }

    /// Returns the minimum of the elements whose index lies in [l, r]
    /// (inclusive). The range is clamped to [0, n-1]; if nothing remains
    /// (empty tree, l > r, or a range entirely outside the array) the
    /// result is 1000000000.
    int query(int l, int r) const {
        if (l < 0) {
            l = 0;
        }
        if (r > n - 1) {
            r = n - 1;
        }
        if (l > r) {
            return INF;
        }
        return query(1, 0, n - 1, l, r);
    }
};
Lazy Propagation
สำหรับ Range Updates ใช้ Lazy Propagation เพื่อ defer updates
Range Add + Range Sum
/// Segment tree with lazy propagation supporting "add `val` to every element
/// in [l, r]" and inclusive range-sum queries.
///
/// Nodes live in flat, 1-indexed arrays: node `i` has children `2*i` and
/// `2*i + 1`, and the root (node 1) covers the index range [0, n-1].
/// A tree of size 0 is valid: sums are 0 and updates are ignored.
class SegmentTreeLazy {
private:
    std::vector<long long> tree;  // tree[node] = sum over the node's range
    std::vector<long long> lazy;  // lazy[node] = per-element add not yet applied
    int n;                        // number of elements

    /// Applies the pending add of `node` to its sum and hands the same add
    /// down to its children (when it has any).
    void push(int node, int start, int end) {
        if (lazy[node] == 0) {
            return;
        }
        tree[node] += lazy[node] * (end - start + 1);
        if (start != end) {
            lazy[2 * node] += lazy[node];
            lazy[2 * node + 1] += lazy[node];
        }
        lazy[node] = 0;
    }

    /// Adds `val` to every position in the intersection of [l, r] with
    /// [start, end].
    void update(int node, int start, int end, int l, int r, long long val) {
        // Push first, even when the node turns out to be out of range, so
        // the parent can safely recombine tree[] of both children below.
        push(node, start, end);
        if (r < start || end < l) {
            return;
        }
        if (l <= start && end <= r) {
            lazy[node] += val;
            push(node, start, end);
            return;
        }
        const int mid = start + (end - start) / 2;
        update(2 * node, start, mid, l, r, val);
        update(2 * node + 1, mid + 1, end, l, r, val);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    /// Returns the sum over the intersection of [l, r] with [start, end].
    long long query(int node, int start, int end, int l, int r) {
        push(node, start, end);
        if (r < start || end < l) {
            return 0;
        }
        if (l <= start && end <= r) {
            return tree[node];
        }
        const int mid = start + (end - start) / 2;
        return query(2 * node, start, mid, l, r) +
               query(2 * node + 1, mid + 1, end, l, r);
    }

public:
    /// Creates a tree of `size` elements, all initially 0.
    SegmentTreeLazy(int size) {
        n = size;
        tree.assign(4 * n, 0);
        lazy.assign(4 * n, 0);
    }

    /// Creates a tree holding the values of `arr`.
    SegmentTreeLazy(const std::vector<int>& arr) {
        n = static_cast<int>(arr.size());
        tree.assign(4 * n, 0);
        lazy.assign(4 * n, 0);
        build(arr, 1, 0, n - 1);
    }

    /// Fills `node`, which covers [start, end], and its subtree from `arr`.
    /// Called by the array constructor with (arr, 1, 0, n-1). An empty range
    /// (start > end) is a no-op, which makes an empty array safe.
    void build(const std::vector<int>& arr, int node, int start, int end) {
        if (start > end) {
            return;
        }
        if (start == end) {
            tree[node] = arr[start];
            return;
        }
        const int mid = start + (end - start) / 2;
        build(arr, 2 * node, start, mid);
        build(arr, 2 * node + 1, mid + 1, end);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    /// Adds `val` to every element whose index lies in [l, r] (inclusive).
    /// Parts of the range outside the array are ignored.
    void rangeAdd(int l, int r, long long val) {
        if (n <= 0) {
            return;  // no root node exists to push or update
        }
        update(1, 0, n - 1, l, r, val);
    }

    /// Returns the sum of the elements whose index lies in [l, r]
    /// (inclusive). Parts of the range outside the array contribute 0.
    long long rangeSum(int l, int r) {
        if (n <= 0) {
            return 0;
        }
        return query(1, 0, n - 1, l, r);
    }
};
Range Set + Range Sum
/// Segment tree with lazy propagation supporting "assign `val` to every
/// element in [l, r]" and inclusive range-sum queries.
///
/// Nodes live in flat, 1-indexed arrays: node `i` has children `2*i` and
/// `2*i + 1`, and the root (node 1) covers the index range [0, n-1].
/// A tree of size 0 is valid: sums are 0 and updates are ignored.
class SegmentTreeRangeSet {
private:
    std::vector<long long> tree;  // tree[node] = sum over the node's range
    std::vector<long long> lazy;  // pending assigned value, valid iff hasLazy
    // A separate flag is required because 0 is a legitimate value to assign.
    // Stored as char instead of vector<bool> to avoid the packed-bit proxy.
    std::vector<char> hasLazy;
    int n;                        // number of elements

    /// Applies the pending assignment of `node` to its sum and hands it down
    /// to its children (when it has any), replacing whatever they had
    /// pending.
    void push(int node, int start, int end) {
        if (!hasLazy[node]) {
            return;
        }
        tree[node] = lazy[node] * (end - start + 1);
        if (start != end) {
            lazy[2 * node] = lazy[node];
            lazy[2 * node + 1] = lazy[node];
            hasLazy[2 * node] = 1;
            hasLazy[2 * node + 1] = 1;
        }
        hasLazy[node] = 0;
    }

    /// Assigns `val` to every position in the intersection of [l, r] with
    /// [start, end].
    void update(int node, int start, int end, int l, int r, long long val) {
        // Push first, even when the node turns out to be out of range, so
        // the parent can safely recombine tree[] of both children below.
        push(node, start, end);
        if (r < start || end < l) {
            return;
        }
        if (l <= start && end <= r) {
            lazy[node] = val;
            hasLazy[node] = 1;
            push(node, start, end);
            return;
        }
        const int mid = start + (end - start) / 2;
        update(2 * node, start, mid, l, r, val);
        update(2 * node + 1, mid + 1, end, l, r, val);
        tree[node] = tree[2 * node] + tree[2 * node + 1];
    }

    /// Returns the sum over the intersection of [l, r] with [start, end].
    long long query(int node, int start, int end, int l, int r) {
        push(node, start, end);
        if (r < start || end < l) {
            return 0;
        }
        if (l <= start && end <= r) {
            return tree[node];
        }
        const int mid = start + (end - start) / 2;
        return query(2 * node, start, mid, l, r) +
               query(2 * node + 1, mid + 1, end, l, r);
    }

public:
    /// Creates a tree of `size` elements, all initially 0.
    SegmentTreeRangeSet(int size) {
        n = size;
        tree.assign(4 * n, 0);
        lazy.assign(4 * n, 0);
        hasLazy.assign(4 * n, 0);
    }

    /// Assigns `val` to every element whose index lies in [l, r]
    /// (inclusive). Parts of the range outside the array are ignored.
    void rangeSet(int l, int r, long long val) {
        if (n <= 0) {
            return;  // no root node exists to push or update
        }
        update(1, 0, n - 1, l, r, val);
    }

    /// Returns the sum of the elements whose index lies in [l, r]
    /// (inclusive). Parts of the range outside the array contribute 0.
    long long rangeSum(int l, int r) {
        if (n <= 0) {
            return 0;
        }
        return query(1, 0, n - 1, l, r);
    }
};
Iterative Segment Tree
Version ที่ไม่ใช้ recursion (เร็วกว่า)
/// Non-recursive (bottom-up) segment tree supporting point assignment and
/// inclusive range-sum queries.
///
/// Layout: the n leaves occupy tree[n .. 2n-1] (element i is tree[n + i]),
/// and each internal node i in [1, n-1] stores tree[2*i] + tree[2*i + 1].
/// tree[0] is unused.
class IterativeSegTree {
private:
    std::vector<long long> tree;
    int n;  // number of elements

public:
    /// Creates a tree of `size` elements, all initially 0.
    IterativeSegTree(int size) {
        n = size;
        tree.assign(2 * n, 0);
    }

    /// Creates a tree holding the values of `arr`.
    IterativeSegTree(const std::vector<int>& arr) {
        n = static_cast<int>(arr.size());
        tree.resize(2 * arr.size());
        // Leaves first, then every internal node from the bottom up so both
        // children are final before their parent is computed.
        for (int i = 0; i < n; ++i) {
            tree[n + i] = arr[i];
        }
        for (int i = n - 1; i > 0; --i) {
            tree[i] = tree[2 * i] + tree[2 * i + 1];
        }
    }

    /// Assigns `val` to position `idx` and refreshes all of its ancestors.
    /// Indices outside [0, n-1] are ignored; writing them would land outside
    /// the leaf area (or outside the vector).
    void update(int idx, long long val) {
        if (idx < 0 || idx >= n) {
            return;
        }
        int pos = idx + n;
        tree[pos] = val;
        for (pos /= 2; pos >= 1; pos /= 2) {
            tree[pos] = tree[2 * pos] + tree[2 * pos + 1];
        }
    }

    /// Returns the sum of the elements whose index lies in [l, r]
    /// (inclusive). The range is clamped to [0, n-1]; an empty result
    /// range yields 0.
    long long query(int l, int r) const {
        if (l < 0) {
            l = 0;
        }
        if (r >= n) {
            r = n - 1;
        }
        if (l > r) {
            return 0;
        }
        long long sum = 0;
        // Work on the half-open leaf interval [lo, hi) and climb one level
        // per iteration. An odd `lo` is a right child whose parent also
        // covers something outside the range, so take it and step right;
        // symmetrically for an odd `hi`.
        int lo = l + n;
        int hi = r + n + 1;
        while (lo < hi) {
            if (lo & 1) {
                sum += tree[lo++];
            }
            if (hi & 1) {
                sum += tree[--hi];
            }
            lo /= 2;
            hi /= 2;
        }
        return sum;
    }
};
Complexity
| Operation | Complexity |
|---|---|
| Build | $O(n)$ |
| Point Update | $O(\log n)$ |
| Range Update (Lazy) | $O(\log n)$ |
| Query | $O(\log n)$ |
| Space | $O(n)$ |
ตัวอย่างการใช้งาน
int main() {
vector<int> arr = {1, 3, 5, 7, 9, 11};
SegmentTree st(arr);
// Query sum [1, 3]
cout << st.query(1, 3) << endl; // 3 + 5 + 7 = 15
// Update index 2 to 10
st.update(2, 10);
// Query sum [1, 3] again
cout << st.query(1, 3) << endl; // 3 + 10 + 7 = 20
}
เลือกใช้อะไร?
| สถานการณ์ | ใช้ |
|---|---|
| Point update + Range query | Basic Segment Tree |
| Range update + Range query | Segment Tree + Lazy |
| เน้น performance | Iterative Segment Tree |
| Simple range sum | Fenwick Tree (ง่ายกว่า) |
💡 Tip: ถ้าต้องการแค่ range sum และ point update ให้ใช้ Fenwick Tree แทน จะง่ายกว่า