Heavy-Light Decomposition

Bài viết này dựa trên bài viết về Heavy-Light Decomposition trên VNOI. Vì bài viết trên VNOI đã giải thích lý thuyết khá cụ thể, ở đây chỉ tập trung vào phần cài đặt (code).

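Dưới đây là đoạn code xây dựng HLD (tính kích thước cây con, sau đó DFS ưu tiên con nặng để đánh số các đỉnh sao cho mỗi chuỗi nặng là một đoạn liên tiếp):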

#include <bits/stdc++.h>

using namespace std;

// Per-vertex record produced by the heavy-light decomposition.
struct HLD {
    int position;  // index of the vertex in hldList (DFS order, heavy child first)
    int highest;   // top vertex of the heavy chain containing this vertex
    HLD() : position(0), highest(0) {}
    HLD(int p, int h) : position(p), highest(h) {}
};

// Capacity for vertices; vertices are numbered 1..numNode.
const int N = 100007;
// Number of vertices read from input.
int numNode;
// Adjacency lists of the undirected tree.
vector<int> adj[N];
// Vertices in HLD order: hldList[hld[u].position] == u.
vector<int> hldList;
// sz[u] = number of vertices in the subtree of u (filled by dfs).
int sz[N];
// HLD record (position and chain top) of each vertex.
HLD hld[N];

// Computes sz[u], the number of vertices in the subtree of u, for the tree
// rooted at the first call's u. p is u's parent (0 for the root, because
// vertices are 1-indexed). Recursion depth equals the tree height.
void dfs(int u, int p) {
    int total = 1;
    for (int child : adj[u]) {
        if (child == p) continue;
        dfs(child, u);
        total += sz[child];
    }
    sz[u] = total;
}

// Assigns HLD positions in DFS order, visiting the heavy child (largest
// subtree) first so that each heavy chain occupies a contiguous range of
// hldList. `highest` is the top vertex of the chain that u belongs to.
void dfsHLD(int u, int p, int highest) {
    hld[u] = HLD(hldList.size(), highest);
    hldList.push_back(u);

    // Child with the largest subtree; on ties the first one seen wins.
    int heavy = -1;
    for (int child : adj[u]) {
        if (child != p && (heavy == -1 || sz[child] > sz[heavy])) heavy = child;
    }
    if (heavy == -1) return;  // u is a leaf

    dfsHLD(heavy, u, highest);  // the heavy child continues u's chain
    for (int child : adj[u]) {
        // Every light child starts a new chain headed by itself.
        if (child != p && child != heavy) dfsHLD(child, u, child);
    }
}

// Roots the tree at vertex 1, computes subtree sizes, then decomposes it.
void buildHLD() {
    const int root = 1;
    dfs(root, 0);           // 0 is the "no parent" sentinel (vertices are 1-indexed)
    dfsHLD(root, 0, root);  // the root heads the first heavy chain
}

// Reads a tree with numNode vertices given as numNode-1 undirected edges and
// builds its heavy-light decomposition.
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    cin >> numNode;
    for (int remaining = numNode - 1; remaining > 0; --remaining) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    buildHLD();
    return 0;
}

Độ phức tạp xây dựng: \(O(N)\) (hai lần DFS, mỗi đỉnh và mỗi cạnh được duyệt một số lần hằng số).

Nhận xét: khi xử lý các truy vấn trên đường đi từ đỉnh u đến đỉnh v, ta thường cần thêm LCA và kết hợp với các cấu trúc dữ liệu xử lý truy vấn trên đoạn như BIT, Segment Tree, Sparse Table, …, khiến cho code khá dài.

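Độ phức tạp: đường đi từ u đến v đi qua \(O(\log N)\) chuỗi nặng, mỗi chuỗi ứng với một đoạn liên tiếp được xử lý bằng Segment Tree trong \(O(\log N)\), nên mỗi truy vấn tốn \(O(\log^2 N)\).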

Để hiểu thêm về cách xử lí truy vấn, có thể tham khảo code một số bài tập dưới đây:

solution
#include <bits/stdc++.h>
 
using namespace std;
 
// Per-vertex record produced by the heavy-light decomposition.
struct HLD {
    int position;  // index of the vertex in hldList (DFS order, heavy child first)
    int highest;   // top vertex of the heavy chain containing this vertex
    HLD() : position(0), highest(0) {}
    HLD(int p, int h) : position(p), highest(h) {}
};
 
// Weighted tree edge {x, y} with weight w. buildHLD reorders the endpoints so
// that x is the parent and y the child.
struct Edge {
    int x, y, w;
    Edge() : x(0), y(0), w(0) {}
    Edge(int _x, int _y, int _w) : x(_x), y(_y), w(_w) {}
};
 
// Point-assign / range-max segment tree over indices [0, size).
// Leaves start at 0; an empty query range yields INT_MIN.
struct segtree {
    int size;
    vector<int> node;

    // Rounds the capacity up to a power of two >= numNode and zeroes all nodes.
    void init(int numNode) {
        size = 1;
        while (size < numNode) size *= 2;
        node.assign(2 * size, 0);
    }

    // Sets position i to k within the subtree of node id covering [l, r].
    void update(int id, int l, int r, int i, int k) {
        if (i < l || i > r) return;
        if (l == r) {
            node[id] = k;
            return;
        }
        const int mid = (l + r) / 2;
        const int left = 2 * id, right = 2 * id + 1;
        update(left, l, mid, i, k);
        update(right, mid + 1, r, i, k);
        node[id] = max(node[left], node[right]);
    }
    void update(int i, int k) { update(1, 0, size - 1, i, k); }

    // Maximum over [x, y] intersected with node id's range [l, r].
    int getMax(int id, int l, int r, int x, int y) {
        if (y < l || r < x) return INT_MIN;
        if (x <= l && r <= y) return node[id];
        const int mid = (l + r) / 2;
        const int a = getMax(2 * id, l, mid, x, y);
        const int b = getMax(2 * id + 1, mid + 1, r, x, y);
        return max(a, b);
    }
    // Maximum over [x, y]; INT_MIN when the range is empty (x > y).
    int getMax(int x, int y) {
        return x > y ? INT_MIN : getMax(1, 0, size - 1, x, y);
    }
} st;
 
// Capacity for vertices (1-indexed) and number of binary-lifting levels;
// 2^LOG exceeds N, so any depth can be climbed.
const int N = 10007;
const int LOG = 15;
int numNode;          // vertices in the current test case
int sz[N];            // subtree sizes
int anc[N][LOG];      // anc[u][i] = 2^i-th ancestor of u (0 above the root)
int depth[N];         // depth of each vertex; depth[0] = -1 is set in buildHLD
vector<int> adj[N];   // adj[u] = indices into `edge` of the edges touching u
vector<int> hldList;  // vertices in HLD order
vector<Edge> edge;    // edges of the current test case
HLD hld[N];           // HLD record of each vertex
 
void dfs(int u, int p) {
    for (int i = 1; i < LOG; ++i) {
        anc[u][i] = anc[anc[u][i-1]][i-1];
    }
    sz[u] = 1;
    for (int i : adj[u]) {
        int v = u ^ edge[i].x ^ edge[i].y;
        if (v != p) {
            anc[v][0] = u;
            depth[v] = depth[u] + 1;
            dfs(v, u);
            sz[u] += sz[v];
        }
    }
}
 
void dfsHLD(int u, int p, int highest) {
    hld[u] = HLD(hldList.size(), highest);
    hldList.push_back(u);
    int bigNode = -1;
    for (int i : adj[u]) {
        int v = u ^ edge[i].x ^ edge[i].y;
        if (v != p) {
            if (bigNode == -1 || sz[bigNode] < sz[v]) {
                bigNode = v;
            }
        }
    }
    if (bigNode == -1) return;
    dfsHLD(bigNode, u, highest);
    for (int i : adj[u]) {
        int v = u ^ edge[i].x ^ edge[i].y;
        if (v!=p && v!=bigNode) {
            dfsHLD(v, u, v);
        }
    }
}
 
void buildHLD() {
    depth[0] = -1;
    dfs(1,0);
    dfsHLD(1, 0, 1);
    st.init(numNode);
    for (int i = 0; i < edge.size(); ++i) {
        if (sz[edge[i].x] < sz[edge[i].y]) swap(edge[i].x,edge[i].y);
        st.update(hld[edge[i].y].position, edge[i].w);
    }
}
 
// Lowest common ancestor of u and v via binary lifting.
int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    // Lift the deeper vertex u to v's depth.
    for (int lvl = LOG - 1; lvl >= 0; --lvl) {
        const int up = anc[u][lvl];
        if (depth[up] >= depth[v]) u = up;
    }
    if (u == v) return u;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int lvl = LOG - 1; lvl >= 0; --lvl) {
        if (anc[u][lvl] == anc[v][lvl]) continue;
        u = anc[u][lvl];
        v = anc[v][lvl];
    }
    return anc[u][0];
}
 
// Raises x to y when y is larger.
void maximize(int& x, int y) {
    x = max(x, y);
}
 
void update(int id, int val) {
    st.update(hld[edge[id].y].position, val);
}
 
// Maximum edge weight on the path u -> v.
// Each edge weight is stored at the HLD position of its child endpoint, so the
// LCA's own position is excluded. The path u -> u has no edges: return 0
// instead of leaking the segment tree's INT_MIN sentinel, matching the
// NEGATE variant of getMax.
int getMax(int u, int v) {
    if (u == v) return 0;
    const int k = lca(u, v);
    int maxs = INT_MIN;
    for (int x : {u, v}) {
        // Whole chain pieces until x reaches the LCA's chain.
        for (; hld[x].highest != hld[k].highest; x = anc[hld[x].highest][0])
            maximize(maxs, st.getMax(hld[hld[x].highest].position, hld[x].position));
        // Remaining piece on the LCA's chain, strictly below k.
        maximize(maxs, st.getMax(hld[k].position + 1, hld[x].position));
    }
    return maxs;
}
 
// Multi-test driver. For each test it reads a weighted tree, then processes
// commands until "DONE":
// - "CHANGE i w" sets edge i (1-based) to w;
// - any other command "<cmd> u v" prints the path maximum between u and v.
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    int nTest;
    cin >> nTest;
    while (nTest--) {
        cin >> numNode;
        for (int i = 1; i < numNode; ++i) {
            int x, y, w;
            cin >> x >> y >> w;
            const int id = edge.size();
            adj[x].push_back(id);
            adj[y].push_back(id);
            edge.emplace_back(x, y, w);
        }
        buildHLD();
        for (string cmd; cin >> cmd && cmd != "DONE"; ) {
            int a, b;
            cin >> a >> b;
            if (cmd == "CHANGE") update(a - 1, b);  // a = edge number, b = new weight
            else cout << getMax(a, b) << '\n';
        }
        // Reset per-test containers; the arrays are overwritten by buildHLD.
        hldList.clear();
        edge.clear();
        for (int i = 1; i <= numNode; ++i) adj[i] = vector<int>();
    }
    return 0;
}
solution
#include <bits/stdc++.h>

using namespace std;

// Per-vertex record produced by the heavy-light decomposition.
struct HLD {
    int position;  // index of the vertex in hldList (DFS order, heavy child first)
    int highest;   // top vertex of the heavy chain containing this vertex
    HLD() : position(0), highest(0) {}
    HLD(int p, int h) : position(p), highest(h) {}
};

// Segment tree over positions [0, size). Each leaf is either INT_MAX ("off")
// or its own index ("on"), and internal nodes store the minimum. A range
// query therefore returns the smallest "on" index, or INT_MAX if there is none.
struct segtree {
    int size;
    vector<int> node;

    void init(int numNode) {
        size = 1;
        while (size < numNode) size *= 2;
        node.assign(2 * size, INT_MAX);  // every leaf starts "off"
    }

    // Toggles leaf i between "off" (INT_MAX) and "on" (value i).
    void update(int id, int l, int r, int i) {
        if (i < l || i > r) return;
        if (l == r) {
            node[id] = (node[id] == INT_MAX) ? i : INT_MAX;
            return;
        }
        const int mid = (l + r) / 2;
        update(2 * id, l, mid, i);
        update(2 * id + 1, mid + 1, r, i);
        node[id] = min(node[2 * id], node[2 * id + 1]);
    }
    void update(int i) { update(1, 0, size - 1, i); }

    int getMin(int id, int l, int r, int x, int y) {
        if (y < l || r < x) return INT_MAX;
        if (x <= l && r <= y) return node[id];
        const int mid = (l + r) / 2;
        return min(getMin(2 * id, l, mid, x, y), getMin(2 * id + 1, mid + 1, r, x, y));
    }
    // Smallest "on" index in [x, y]; INT_MAX if none or if the range is empty.
    int getMin(int x, int y) {
        if (x > y) return INT_MAX;
        return getMin(1, 0, size - 1, x, y);
    }
} st;

const int N = 100007;  // capacity for vertices (1-indexed)
int numNode, numQue;   // vertex and query counts
int sz[N];             // subtree sizes
int par[N];            // parent of each vertex (par[1] stays 0)
vector<int> adj[N];    // adjacency lists
vector<int> hldList;   // vertices in HLD order
HLD hld[N];            // HLD record of each vertex

// Records par[] and subtree sizes sz[] for the subtree rooted at u (parent p).
// Recursion depth equals the tree height.
void dfs(int u, int p) {
    sz[u] = 1;
    for (int child : adj[u]) {
        if (child == p) continue;
        par[child] = u;
        dfs(child, u);
        sz[u] += sz[child];
    }
}

// Assigns HLD positions in DFS order, visiting the heavy child (largest
// subtree) first so that each heavy chain occupies a contiguous range of
// hldList. `highest` is the top vertex of the chain that u belongs to.
void dfsHLD(int u, int p, int highest) {
    hld[u] = HLD(hldList.size(), highest);
    hldList.push_back(u);

    // Child with the largest subtree; on ties the first one seen wins.
    int heavy = -1;
    for (int child : adj[u]) {
        if (child != p && (heavy == -1 || sz[child] > sz[heavy])) heavy = child;
    }
    if (heavy == -1) return;  // u is a leaf

    dfsHLD(heavy, u, highest);  // the heavy child continues u's chain
    for (int child : adj[u]) {
        // Every light child starts a new chain headed by itself.
        if (child != p && child != heavy) dfsHLD(child, u, child);
    }
}

// Roots the tree at vertex 1, decomposes it, and (re)initializes the segment
// tree, whose leaves all start at INT_MAX.
void buildHLD() {
    const int root = 1;
    dfs(root, 0);
    dfsHLD(root, 0, root);
    st.init(numNode);
}

// Lowers x to y when y is smaller.
void minimize(int& x, int y) {
    x = min(x, y);
}

// Toggles the on/off state of vertex u, stored at its HLD position.
void update(int u) {
    const int pos = hld[u].position;
    st.update(pos);
}

// Returns the "on" vertex closest to the root on the path 1 -> u, or -1 if
// none is on. Positions follow DFS preorder, so along a root path a smaller
// HLD position means a shallower vertex; the minimum position is the answer.
int FindFirst(int u) {
    int best = INT_MAX;
    const int rootChain = hld[1].highest;
    while (hld[u].highest != rootChain) {
        const int top = hld[u].highest;
        minimize(best, st.getMin(hld[top].position, hld[u].position));
        u = par[top];
    }
    minimize(best, st.getMin(hld[1].position, hld[u].position));
    return best == INT_MAX ? -1 : hldList[best];
}

// Reads the tree, then answers queries of the form "t u":
// - t == 0 toggles vertex u;
// - any other t prints FindFirst(u).
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    cin >> numNode >> numQue;
    for (int e = 1; e < numNode; ++e) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    buildHLD();
    while (numQue--) {
        int type, u;
        cin >> type >> u;
        if (type != 0) cout << FindFirst(u) << '\n';
        else update(u);
    }
    return 0;
}
solution
#include <bits/stdc++.h>
 
using namespace std;
 
// Per-vertex record produced by the heavy-light decomposition.
struct HLD {
    int position;  // index of the vertex in hldList (DFS order, heavy child first)
    int highest;   // top vertex of the heavy chain containing this vertex
    HLD() : position(0), highest(0) {}
    HLD(int p, int h) : position(p), highest(h) {}
};
 
// Weighted tree edge {x, y} with weight w. buildHLD reorders the endpoints so
// that x is the parent and y the child.
struct Edge {
    int x, y, w;
    Edge() : x(0), y(0), w(0) {}
    Edge(int _x, int _y, int _w) : x(_x), y(_y), w(_w) {}
};
 
// Segment tree over [0, size) supporting point assignment, lazy range
// negation and range maximum. Each node keeps both the max and the min of its
// range, because negating a range turns its minimum into the new maximum.
// lazy[id] is +1 (nothing pending) or -1 (children still need negating).
struct segtree {
    int size;
    vector<int> maxs, mins, lazy;

    void init(int numNode) {
        size = 1;
        while (size < numNode) size *= 2;
        maxs.assign(2 * size, 0);
        mins.assign(2 * size, 0);
        lazy.assign(2 * size, 1);
    }

    // Negates every value under node id and flips its pending flag.
    void applyNeg(int id) {
        lazy[id] = -lazy[id];
        const int oldMax = maxs[id];
        maxs[id] = -mins[id];
        mins[id] = -oldMax;
    }

    // Pushes a pending negation of node id down to its two children.
    void pushDown(int id) {
        if (lazy[id] == 1) return;
        applyNeg(2 * id);
        applyNeg(2 * id + 1);
        lazy[id] = 1;
    }

    // Recomputes node id from its two children.
    void pull(int id) {
        maxs[id] = max(maxs[2 * id], maxs[2 * id + 1]);
        mins[id] = min(mins[2 * id], mins[2 * id + 1]);
    }

    void update(int id, int l, int r, int i, int k) {
        if (i < l || i > r) return;
        if (l == r) {
            maxs[id] = mins[id] = k;
            return;
        }
        pushDown(id);
        const int mid = (l + r) / 2;
        update(2 * id, l, mid, i, k);
        update(2 * id + 1, mid + 1, r, i, k);
        pull(id);
    }
    // Sets position i to k.
    void update(int i, int k) { update(1, 0, size - 1, i, k); }

    void Negate(int id, int l, int r, int x, int y) {
        if (y < l || r < x) return;
        if (x <= l && r <= y) {
            applyNeg(id);
            return;
        }
        pushDown(id);
        const int mid = (l + r) / 2;
        Negate(2 * id, l, mid, x, y);
        Negate(2 * id + 1, mid + 1, r, x, y);
        pull(id);
    }
    // Negates every value in [x, y]; no-op when x > y.
    void Negate(int x, int y) {
        if (x <= y) Negate(1, 0, size - 1, x, y);
    }

    int getMax(int id, int l, int r, int x, int y) {
        if (y < l || r < x) return INT_MIN;
        if (x <= l && r <= y) return maxs[id];
        pushDown(id);
        const int mid = (l + r) / 2;
        return max(getMax(2 * id, l, mid, x, y), getMax(2 * id + 1, mid + 1, r, x, y));
    }
    // Maximum over [x, y]; INT_MIN when the range is empty.
    int getMax(int x, int y) {
        return x > y ? INT_MIN : getMax(1, 0, size - 1, x, y);
    }
} st;
 
// Capacity for vertices (1-indexed) and number of binary-lifting levels;
// 2^LOG exceeds N, so any depth can be climbed.
const int N = 10007;
const int LOG = 15;
int numNode;          // vertices in the current test case
int sz[N];            // subtree sizes
int anc[N][LOG];      // anc[u][i] = 2^i-th ancestor of u (0 above the root)
int depth[N];         // depth of each vertex; depth[0] = -1 is set in buildHLD
vector<int> adj[N];   // adj[u] = indices into `edge` of the edges touching u
vector<int> hldList;  // vertices in HLD order
vector<Edge> edge;    // edges of the current test case
HLD hld[N];           // HLD record of each vertex
 
void dfs(int u, int p) {
    for (int i = 1; i < LOG; ++i) {
        anc[u][i] = anc[anc[u][i-1]][i-1];
    }
    sz[u] = 1;
    for (int i : adj[u]) {
        int v = u ^ edge[i].x ^ edge[i].y;
        if (v != p) {
            anc[v][0] = u;
            depth[v] = depth[u] + 1;
            dfs(v, u);
            sz[u] += sz[v];
        }
    }
}
 
void dfsHLD(int u, int p, int highest) {
    hld[u] = HLD(hldList.size(), highest);
    hldList.push_back(u);
    int bigNode = -1;
    for (int i : adj[u]) {
        int v = u ^ edge[i].x ^ edge[i].y;
        if (v != p) {
            if (bigNode == -1 || sz[bigNode] < sz[v]) {
                bigNode = v;
            }
        }
    }
    if (bigNode == -1) return;
    dfsHLD(bigNode, u, hld[u].highest);
    for (int i : adj[u]) {
        int v = u ^ edge[i].x ^ edge[i].y;
        if (v!=p && v!=bigNode) {
            dfsHLD(v, u, v);
        }
    }
}
 
void buildHLD() {
    depth[0] = -1;
    dfs(1,0);
    dfsHLD(1, 0, 1);
    st.init(numNode);
    for (int i = 0; i < edge.size(); ++i) {
        if (sz[edge[i].x] < sz[edge[i].y]) swap(edge[i].x,edge[i].y);
        st.update(hld[edge[i].y].position, edge[i].w);
    }
}
 
// Lowest common ancestor of u and v via binary lifting.
int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    // Lift the deeper vertex u to v's depth.
    for (int lvl = LOG - 1; lvl >= 0; --lvl) {
        const int up = anc[u][lvl];
        if (depth[up] >= depth[v]) u = up;
    }
    if (u == v) return u;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int lvl = LOG - 1; lvl >= 0; --lvl) {
        if (anc[u][lvl] == anc[v][lvl]) continue;
        u = anc[u][lvl];
        v = anc[v][lvl];
    }
    return anc[u][0];
}
 
// Raises x to y when y is larger.
void maximize(int& x, int y) {
    x = max(x, y);
}
 
void update(int id, int val) {
    st.update(hld[edge[id].y].position, val);
}

// Negates the weight of every edge on the path u -> v. Edge weights live at
// the child endpoint's HLD position, so the LCA's own position is skipped.
void Negate(int u, int v) {
    const int k = lca(u, v);
    for (int x : {u, v}) {
        while (hld[x].highest != hld[k].highest) {
            const int top = hld[x].highest;
            st.Negate(hld[top].position, hld[x].position);
            x = anc[top][0];
        }
        st.Negate(hld[k].position + 1, hld[x].position);
    }
}

// Maximum edge weight on the path u -> v; returns 0 when u == v (no edges).
// The LCA's position is excluded because it stores the edge above the LCA.
int getMax(int u, int v) {
    if (u == v) return 0;
    const int k = lca(u, v);
    int best = INT_MIN;
    for (int x : {u, v}) {
        while (hld[x].highest != hld[k].highest) {
            const int top = hld[x].highest;
            maximize(best, st.getMax(hld[top].position, hld[x].position));
            x = anc[top][0];
        }
        maximize(best, st.getMax(hld[k].position + 1, hld[x].position));
    }
    return best;
}
 
// Multi-test driver. For each test it reads a weighted tree, then processes
// commands until "DONE":
// - "CHANGE i w" sets edge i (1-based) to w;
// - "NEGATE u v" negates the path between u and v;
// - any other command "<cmd> u v" prints the path maximum.
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    int nTest;
    cin >> nTest;
    while (nTest--) {
        cin >> numNode;
        for (int i = 1; i < numNode; ++i) {
            int x, y, w;
            cin >> x >> y >> w;
            const int id = edge.size();
            adj[x].push_back(id);
            adj[y].push_back(id);
            edge.emplace_back(x, y, w);
        }
        buildHLD();
        for (string cmd; cin >> cmd && cmd != "DONE"; ) {
            int a, b;
            cin >> a >> b;
            if (cmd == "CHANGE") update(a - 1, b);
            else if (cmd == "NEGATE") Negate(a, b);
            else cout << getMax(a, b) << '\n';
        }
        // Reset per-test containers; the arrays are overwritten by buildHLD.
        hldList.clear();
        edge.clear();
        for (int i = 1; i <= numNode; ++i) adj[i] = vector<int>();
    }
    return 0;
}
solution
#include <bits/stdc++.h>

using namespace std;

// Per-vertex HLD record. Once dfsHLD has run, the subtree of a vertex occupies
// hldList positions [position, last).
// All members have default initializers. Previously `last` was left
// indeterminate by both constructors, so copying a freshly built HLD read an
// uninitialized int.
struct HLD {
    int position = 0;  // index in hldList (DFS order, heavy child first)
    int highest = 0;   // top vertex of the heavy chain containing the vertex
    int last = 0;      // one past the last hldList index of the subtree (set by dfsHLD)
    HLD() = default;
    HLD(int p, int h) : position(p), highest(h) {}
};

// Capacity for vertices (1-indexed) and number of binary-lifting levels;
// 2^LOG exceeds N.
const int N = 100007;
const int LOG = 18;
int numNode, numQue;  // vertex and query counts
int anc[N][LOG];      // anc[u][i] = 2^i-th ancestor of u (0 above the root)
int depth[N];         // depth of each vertex; depth[0] = -1 is set in buildHLD
int sz[N];            // subtree sizes
vector<int> adj[N];   // adjacency lists
vector<int> hldList;  // vertices in HLD order
HLD hld[N];           // HLD record of each vertex

// Lazy segment tree over HLD positions [0, size). It supports "add k to every
// position in a range" and two range sums. Let a[x] be the total added to
// position x, and let val[id] be the sum of sz[] of the vertices in id's range
// (fixed by build). Then:
//   node[id].first  = sum over the range of a[x] * sz[vertex at x]
//   node[id].second = sum over the range of a[x]
// lazy[id] is the amount added to id's whole range that has not yet been
// pushed to its children.
struct segtree {
    int size;
    vector<long long> lazy, val;
    vector<pair<long long, long long>> node;
    // Rounds the capacity up to a power of two >= n, zeroes sums and lazies,
    // and fills val[]. Reads the globals numNode, hldList and sz, so it must
    // run after the HLD has been built.
    void init(int n) {
        size = 1;
        while(size < n) size <<= 1;
        node.assign(size<<1, make_pair(0,0));
        val.assign(size<<1,0);
        lazy.assign(size<<1,0);
        build(1,0,size-1);
    }
    // Fills val[]. A leaf l < numNode gets sz of the vertex at HLD position l;
    // padding leaves stay 0; an internal node is the sum of its children.
    void build(int id, int l, int r) {
        if (l == r) {
            if (l < numNode) {
                val[id] = sz[hldList[l]];
            }
            return;
        }
        int m = l+r>>1;
        build(id<<1,l,m);
        build(id<<1|1,m+1,r);
        val[id] = val[id<<1] + val[id<<1|1];
    }
    // Applies node id's pending addition (id covers len positions) to both
    // children. size is a power of two, so each child covers len/2 positions.
    void pushDown(int id, int len) {
        if (lazy[id] == 0) return;
        for (int i = id<<1; i <= (id<<1|1); ++i) {
            lazy[i] += lazy[id];
            node[i].first += val[i] * lazy[id];
            node[i].second += lazy[id] * (len>>1);
        }
        lazy[id] = 0;
    }
    // Adds k to every position in [x, y] within node id's range [l, r].
    void add(int id, int l, int r, int x, int y, int k) {
        if (l>y||r<x) return;
        if (l>=x&&r<=y) {
            node[id].first += val[id] * k;
            node[id].second += 1LL * k * (r-l+1);
            lazy[id] += k;
            return;
        }
        pushDown(id, r-l+1);
        int m = l+r>>1;
        add(id<<1,l,m,x,y,k);
        add(id<<1|1,m+1,r,x,y,k);
        node[id].first = node[id<<1].first + node[id<<1|1].first;
        node[id].second = node[id<<1].second + node[id<<1|1].second;
    }
    // Adds k to positions [x, y]; no-op for an empty range (x > y).
    void add(int x, int y, int k) {
        if (x > y) return;
        add(1,0,size-1,x,y,k);
    }
    // Range sum of the .first component (a[x] weighted by subtree size).
    long long sum(int id, int l, int r, int x, int y) {
        if (l>y||r<x) return 0;
        if (l>=x&&r<=y) return node[id].first;
        pushDown(id, r-l+1);
        int m = l+r>>1;
        return sum(id<<1,l,m,x,y) + sum(id<<1|1,m+1,r,x,y);
    }
    long long sum(int x, int y) {
        return sum(1,0,size-1,x,y);
    }
    // Range sum of the .second component (plain a[x]).
    long long sum2(int id, int l, int r, int x, int y) {
        if (l>y||r<x) return 0;
        if (l>=x&&r<=y) return node[id].second;
        pushDown(id, r-l+1);
        int m = l+r>>1;
        return sum2(id<<1,l,m,x,y) + sum2(id<<1|1,m+1,r,x,y);
    }
    long long sum2(int x, int y) {
        return sum2(1,0,size-1,x,y);
    }
} st;

// Processes the subtree of u (parent p). It fills sz[u], completes u's
// binary-lifting row, then sets depth and anc[.][0] for every descendant.
// anc[u][0] and depth[u] must already be set by the caller.
void dfs(int u, int p) {
    sz[u] = 1;
    for (int lvl = 1; lvl < LOG; ++lvl) anc[u][lvl] = anc[anc[u][lvl - 1]][lvl - 1];
    for (int child : adj[u]) {
        if (child == p) continue;
        anc[child][0] = u;
        depth[child] = depth[u] + 1;
        dfs(child, u);
        sz[u] += sz[child];
    }
}

// Heavy-light decomposition that also records hld[u].last. Afterwards the
// subtree of u occupies hldList positions [hld[u].position, hld[u].last).
void dfsHLD(int u, int p, int highest) {
    hld[u] = HLD(hldList.size(), highest);
    hldList.push_back(u);

    int heavy = -1;
    for (int child : adj[u])
        if (child != p && (heavy == -1 || sz[child] > sz[heavy])) heavy = child;

    if (heavy != -1) {
        dfsHLD(heavy, u, highest);  // heavy child continues the chain
        for (int child : adj[u])
            if (child != p && child != heavy) dfsHLD(child, u, child);
    }
    hld[u].last = hldList.size();  // everything pushed since u belongs to its subtree
}

// Roots the tree at vertex 1 (parent sentinel -1), builds binary lifting and
// the HLD, then initializes the segment tree. The subtree of u occupies
// positions [hld[u].position, hld[u].last).
void buildHLD() {
    depth[0] = -1;  // sentinel: the "ancestor" 0 is shallower than every vertex
    const int root = 1, noParent = -1;
    dfs(root, noParent);
    dfsHLD(root, noParent, root);
    st.init(numNode);
}

// Lowest common ancestor of u and v via binary lifting.
int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    // Lift the deeper vertex u to v's depth.
    for (int lvl = LOG - 1; lvl >= 0; --lvl) {
        const int up = anc[u][lvl];
        if (depth[up] >= depth[v]) u = up;
    }
    if (u == v) return u;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int lvl = LOG - 1; lvl >= 0; --lvl) {
        if (anc[u][lvl] == anc[v][lvl]) continue;
        u = anc[u][lvl];
        v = anc[v][lvl];
    }
    return anc[u][0];
}

// Adds k to every vertex on the path u -> v, both endpoints included. The
// LCA is covered exactly once, on u's side.
void add(int u, int v, int k) {
    const int meet = lca(u, v);
    const int meetChain = hld[meet].highest;
    while (hld[u].highest != meetChain) {
        const int head = hld[u].highest;
        st.add(hld[head].position, hld[u].position, k);
        u = anc[head][0];
    }
    st.add(hld[meet].position, hld[u].position, k);  // includes the LCA
    while (hld[v].highest != meetChain) {
        const int head = hld[v].highest;
        st.add(hld[head].position, hld[v].position, k);
        v = anc[head][0];
    }
    st.add(hld[meet].position + 1, hld[v].position, k);  // LCA already counted
}

long long sum(int u, int num) {
    long long s = 0;
    for (; hld[u].highest != hld[1].highest; u = anc[hld[u].highest][0]) {
        s += st.sum2(hld[hld[u].highest].position, hld[u].position);
    }
    s += st.sum2(hld[1].position, hld[u].position);
    return s * num;
}

// Let a[x] be the total added to vertex x. This returns
//   sum over x in subtree(u) of a[x] * sz[x]
//   + sz[u] * (sum over strict ancestors x of u of a[x]),
// which equals, summed over every y in u's subtree, the total added along
// the path root -> y.
long long getVal(int u) {
    const long long inside = st.sum(hld[u].position, hld[u].last - 1);
    if (u == 1) return inside;
    return inside + sum(anc[u][0], sz[u]);
}

// Reads the tree, then answers queries:
// - "1 u v k" adds k on the path between u and v;
// - any other type "t u" prints getVal(u) % 10009.
int main() {
    cin.tie(nullptr)->sync_with_stdio(false);
    cin >> numNode;
    for (int e = 1; e < numNode; ++e) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    buildHLD();
    cin >> numQue;
    while (numQue--) {
        int type;
        cin >> type;
        if (type != 1) {
            int u;
            cin >> u;
            cout << getVal(u) % 10009 << '\n';
            continue;
        }
        int u, v, k;
        cin >> u >> v >> k;
        add(u, v, k);
    }
    return 0;
}

Nguồn tham khảo