ICPC Competitive Programming

Segment Tree

โครงสร้างข้อมูล Segment Tree สำหรับ Range Query

Segment Tree

Segment Tree เป็นโครงสร้างข้อมูลที่ใช้ตอบ range queries และ point/range updates ได้อย่างมีประสิทธิภาพ

เมื่อไหร่ควรใช้ Segment Tree?

  • Range sum/min/max queries
  • Range updates
  • ต้องการทั้ง query และ update ใน $O(\log n)$

โครงสร้าง

Segment Tree เป็น binary tree ที่:

  • แต่ละ node เก็บข้อมูลของ range
  • Root เก็บข้อมูลของทั้ง array
  • Leaf nodes คือแต่ละ element
            [0-7]
           /     \
      [0-3]       [4-7]
      /   \       /   \
   [0-1] [2-3] [4-5] [6-7]
   /  \  /  \  /  \  /  \
  [0][1][2][3][4][5][6][7]

Basic Implementation

Range Sum Query + Point Update

/// Segment tree supporting point assignment and inclusive range-sum queries.
///
/// Layout: node 1 is the root and node v has children 2v and 2v+1, so a
/// buffer of 4n slots is always large enough. Indices are 0-based and every
/// query range [l, r] is inclusive. Sums are stored as long long so that the
/// total of many int elements does not overflow.
class SegmentTree {
private:
    vector<long long> tree;
    int n;

    // Stores in tree[node] the sum of arr[start..end], recursing to the leaves.
    void build(const vector<int>& arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
        } else {
            int mid = (start + end) / 2;
            build(arr, 2*node, start, mid);
            build(arr, 2*node+1, mid+1, end);
            tree[node] = tree[2*node] + tree[2*node+1];
        }
    }

    // Sets element idx to val, then recomputes the sums on the way back up.
    void update(int node, int start, int end, int idx, long long val) {
        if (start == end) {
            tree[node] = val;
        } else {
            int mid = (start + end) / 2;
            if (idx <= mid) {
                update(2*node, start, mid, idx, val);
            } else {
                update(2*node+1, mid+1, end, idx, val);
            }
            tree[node] = tree[2*node] + tree[2*node+1];
        }
    }

    // Returns the sum over the intersection of [l, r] with [start, end].
    long long query(int node, int start, int end, int l, int r) {
        if (r < start || end < l) {
            return 0; // Out of range: contributes nothing to the sum
        }
        if (l <= start && end <= r) {
            return tree[node]; // Fully in range: use the precomputed sum
        }
        int mid = (start + end) / 2;
        return query(2*node, start, mid, l, r) +
               query(2*node+1, mid+1, end, l, r);
    }

public:
    /// Builds the tree from arr in O(n). An empty arr yields an empty tree
    /// whose queries return 0.
    SegmentTree(const vector<int>& arr) {
        n = static_cast<int>(arr.size());
        tree.assign(4 * n, 0);
        // build(arr, 1, 0, -1) on an empty array would index a size-0 buffer.
        if (n > 0) {
            build(arr, 1, 0, n-1);
        }
    }

    /// Sets arr[idx] = val in O(log n). idx must lie in [0, n).
    void update(int idx, long long val) {
        if (n == 0) return; // nothing to update in an empty tree
        update(1, 0, n-1, idx, val);
    }

    /// Returns arr[l] + ... + arr[r] in O(log n); 0 for an empty tree.
    long long query(int l, int r) {
        if (n == 0) return 0;
        return query(1, 0, n-1, l, r);
    }
};

Range Minimum Query

/// Segment tree answering inclusive range-minimum queries over a static array.
///
/// Layout matches the sum tree: node 1 is the root, children of v are 2v and
/// 2v+1, and the buffer holds 4n slots.
class SegmentTreeMin {
private:
    vector<int> tree;
    int n;
    // Identity element for min: returned for segments outside the query.
    // It must be >= every possible element, so use the largest int. A value
    // like 1e9 is wrong for elements above 1e9. static constexpr (rather than
    // a non-static const member) also keeps the class copy-assignable.
    static constexpr int INF = numeric_limits<int>::max();

    // Stores in tree[node] the minimum of arr[start..end].
    void build(const vector<int>& arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
        } else {
            int mid = (start + end) / 2;
            build(arr, 2*node, start, mid);
            build(arr, 2*node+1, mid+1, end);
            tree[node] = min(tree[2*node], tree[2*node+1]);
        }
    }

    // Returns the minimum over the intersection of [l, r] with [start, end],
    // or INF when they do not intersect.
    int query(int node, int start, int end, int l, int r) {
        if (r < start || end < l) return INF;
        if (l <= start && end <= r) return tree[node];

        int mid = (start + end) / 2;
        return min(query(2*node, start, mid, l, r),
                   query(2*node+1, mid+1, end, l, r));
    }

public:
    /// Builds the tree from arr in O(n). An empty arr yields an empty tree.
    SegmentTreeMin(const vector<int>& arr) {
        n = static_cast<int>(arr.size());
        tree.assign(4 * n, 0);
        // build(arr, 1, 0, -1) on an empty array would index a size-0 buffer.
        if (n > 0) {
            build(arr, 1, 0, n-1);
        }
    }

    /// Returns min(arr[l..r]) in O(log n); INF for an empty tree.
    int query(int l, int r) {
        if (n == 0) return INF;
        return query(1, 0, n-1, l, r);
    }
};

Lazy Propagation

สำหรับ Range Updates ใช้ Lazy Propagation เพื่อ defer updates

Range Add + Range Sum

/// Segment tree with lazy propagation: range add and inclusive range sum,
/// both in O(log n).
///
/// Invariant: lazy[node] is an addition that has NOT yet been applied to
/// tree[node]. push() applies it to the node, defers it to the children,
/// and clears it. Every visit pushes before reading tree[node].
class SegmentTreeLazy {
private:
    vector<long long> tree, lazy;
    int n;

    // Applies the pending addition of this node to its sum (once per
    // element in [start, end]) and hands it down to the children.
    void push(int node, int start, int end) {
        if (lazy[node] != 0) {
            tree[node] += lazy[node] * (end - start + 1);
            if (start != end) {
                lazy[2*node] += lazy[node];
                lazy[2*node+1] += lazy[node];
            }
            lazy[node] = 0;
        }
    }

    // Adds val to every element of [l, r] ∩ [start, end].
    void update(int node, int start, int end, int l, int r, long long val) {
        push(node, start, end);

        if (r < start || end < l) return;

        if (l <= start && end <= r) {
            // Fully covered: record the addition and apply it to this node
            // immediately, so the parent can recompute its sum from ours.
            lazy[node] += val;
            push(node, start, end);
            return;
        }

        int mid = (start + end) / 2;
        update(2*node, start, mid, l, r, val);
        update(2*node+1, mid+1, end, l, r, val);
        tree[node] = tree[2*node] + tree[2*node+1];
    }

    // Returns the sum over [l, r] ∩ [start, end].
    long long query(int node, int start, int end, int l, int r) {
        push(node, start, end);

        if (r < start || end < l) return 0;
        if (l <= start && end <= r) return tree[node];

        int mid = (start + end) / 2;
        return query(2*node, start, mid, l, r) +
               query(2*node+1, mid+1, end, l, r);
    }

public:
    /// Creates a tree of `size` zero-valued elements.
    SegmentTreeLazy(int size) {
        n = size;
        tree.assign(4 * n, 0);
        lazy.assign(4 * n, 0);
    }

    /// Creates a tree initialized from arr in O(n).
    SegmentTreeLazy(const vector<int>& arr) {
        n = static_cast<int>(arr.size());
        tree.assign(4 * n, 0);
        lazy.assign(4 * n, 0);
        // build(arr, 1, 0, -1) on an empty array would index a size-0 buffer.
        if (n > 0) {
            build(arr, 1, 0, n-1);
        }
    }

    /// Stores in tree[node] the sum of arr[start..end]. Called by the
    /// constructor; it does not reset lazy tags.
    void build(const vector<int>& arr, int node, int start, int end) {
        if (start == end) {
            tree[node] = arr[start];
        } else {
            int mid = (start + end) / 2;
            build(arr, 2*node, start, mid);
            build(arr, 2*node+1, mid+1, end);
            tree[node] = tree[2*node] + tree[2*node+1];
        }
    }

    /// Adds val to every element in [l, r]. A no-op on an empty tree.
    void rangeAdd(int l, int r, long long val) {
        if (n == 0) return;
        update(1, 0, n-1, l, r, val);
    }

    /// Returns the sum of elements in [l, r]; 0 on an empty tree.
    long long rangeSum(int l, int r) {
        if (n == 0) return 0;
        return query(1, 0, n-1, l, r);
    }
};

Range Set + Range Sum

/// Segment tree with lazy propagation: range assignment and inclusive range
/// sum, both in O(log n). All elements start at 0.
///
/// Invariant: when hasLazy[node] is set, the whole segment of `node` must
/// become lazy[node], but tree[node] may not reflect it yet. A separate flag
/// (instead of a sentinel value) lets 0 be assigned like any other value.
class SegmentTreeRangeSet {
private:
    vector<long long> tree;
    vector<long long> lazy;
    // vector<char> rather than vector<bool>: a plain byte per node avoids the
    // bit-packed proxy specialization.
    vector<char> hasLazy;
    int n;

    // Materializes a pending assignment on this node and overrides any older
    // pending assignment on the children.
    void push(int node, int start, int end) {
        if (hasLazy[node]) {
            tree[node] = lazy[node] * (end - start + 1);
            if (start != end) {
                lazy[2*node] = lazy[2*node+1] = lazy[node];
                hasLazy[2*node] = hasLazy[2*node+1] = 1;
            }
            hasLazy[node] = 0;
        }
    }

    // Sets every element of [l, r] ∩ [start, end] to val.
    void update(int node, int start, int end, int l, int r, long long val) {
        push(node, start, end);

        if (r < start || end < l) return;

        if (l <= start && end <= r) {
            // Fully covered: tag this node and apply the tag right away so
            // the parent can recompute its sum from ours.
            lazy[node] = val;
            hasLazy[node] = 1;
            push(node, start, end);
            return;
        }

        int mid = (start + end) / 2;
        update(2*node, start, mid, l, r, val);
        update(2*node+1, mid+1, end, l, r, val);
        tree[node] = tree[2*node] + tree[2*node+1];
    }

    // Returns the sum over [l, r] ∩ [start, end].
    long long query(int node, int start, int end, int l, int r) {
        push(node, start, end);

        if (r < start || end < l) return 0;
        if (l <= start && end <= r) return tree[node];

        int mid = (start + end) / 2;
        return query(2*node, start, mid, l, r) +
               query(2*node+1, mid+1, end, l, r);
    }

public:
    /// Creates a tree of `size` zero-valued elements.
    SegmentTreeRangeSet(int size) {
        n = size;
        tree.assign(4 * n, 0);
        lazy.assign(4 * n, 0);
        hasLazy.assign(4 * n, 0);
    }

    /// Sets every element in [l, r] to val. A no-op on an empty tree.
    void rangeSet(int l, int r, long long val) {
        if (n == 0) return; // push(1, ...) would index a size-0 buffer
        update(1, 0, n-1, l, r, val);
    }

    /// Returns the sum of elements in [l, r]; 0 on an empty tree.
    long long rangeSum(int l, int r) {
        if (n == 0) return 0;
        return query(1, 0, n-1, l, r);
    }
};

Iterative Segment Tree

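Version ที่ไม่ใช้ recursion — รองรับ point update + range query โดยมักเร็วกว่าในทางปฏิบัติ (ไม่มี overhead ของการเรียก recursion) และใช้หน่วยความจำเพียง 2n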

/// Bottom-up (non-recursive) segment tree for point assignment and inclusive
/// range-sum queries.
///
/// Leaves live at tree[n .. 2n-1] and each internal node i holds
/// tree[2i] + tree[2i+1]; tree[0] is unused. This relies on addition being
/// commutative, since n need not be a power of two.
class IterativeSegTree {
private:
    vector<long long> tree;
    int n;

public:
    /// Creates a tree of `size` zero-valued elements.
    IterativeSegTree(int size) {
        n = size;
        tree.assign(2 * n, 0);
    }

    /// Creates a tree initialized from arr in O(n).
    IterativeSegTree(const vector<int>& arr) {
        n = static_cast<int>(arr.size());
        tree.assign(2 * n, 0);

        // Copy array to leaves
        for (int i = 0; i < n; i++) {
            tree[n + i] = arr[i];
        }

        // Build internal nodes bottom-up; children always have larger indices.
        for (int i = n - 1; i > 0; i--) {
            tree[i] = tree[2*i] + tree[2*i+1];
        }
    }

    /// Sets element idx to val and refreshes its ancestors. idx must be in [0, n).
    void update(int idx, long long val) {
        if (n == 0) return; // tree[idx + n] would index a size-0 buffer
        idx += n;
        tree[idx] = val;

        while (idx > 1) {
            idx /= 2;
            tree[idx] = tree[2*idx] + tree[2*idx+1];
        }
    }

    /// Returns the sum over the inclusive range [l, r]; 0 on an empty tree.
    long long query(int l, int r) { // [l, r]
        if (n == 0) return 0;
        long long sum = 0;
        // Switch to the half-open leaf range [l + n, r + n + 1).
        l += n; r += n + 1;

        while (l < r) {
            // An odd left bound is a right child: take it and step past it.
            if (l & 1) sum += tree[l++];
            // An odd right bound means its left neighbour is a right child in range.
            if (r & 1) sum += tree[--r];
            l /= 2; r /= 2;
        }

        return sum;
    }
};

Complexity

| Operation | Time |
|---|---|
| Build | $O(n)$ |
| Point Update | $O(\log n)$ |
| Range Update (Lazy) | $O(\log n)$ |
| Query | $O(\log n)$ |
| Space | $O(n)$ |

ตัวอย่างการใช้งาน

// Demo: range-sum queries before and after a point update on SegmentTree.
int main() {
    vector<int> arr = {1, 3, 5, 7, 9, 11};
    SegmentTree st(arr);

    // Query sum [1, 3]
    cout << st.query(1, 3) << '\n'; // 3 + 5 + 7 = 15

    // Update index 2 to 10
    st.update(2, 10);

    // Query sum [1, 3] again
    // '\n' instead of endl: no forced flush per line; output is flushed at exit.
    cout << st.query(1, 3) << '\n'; // 3 + 10 + 7 = 20
}

เลือกใช้อะไร?

| สถานการณ์ | ใช้ |
|---|---|
| Point update + Range query | Basic Segment Tree |
| Range update + Range query | Segment Tree + Lazy |
| เน้น performance | Iterative Segment Tree |
| Simple range sum | Fenwick Tree (ง่ายกว่า) |

💡 Tip: ถ้าต้องการแค่ range sum และ point update ให้ใช้ Fenwick Tree แทน จะง่ายกว่า