Heavy-Light Decomposition
Bài viết này dựa trên bài viết về Heavy-Light Decomposition trên VNOI. Vì bài viết trên VNOI đã giải thích khá cụ thể nên ở đây chỉ tập trung vào phần code.
Dưới đây là đoạn code xây dựng HLD:
#include <bits/stdc++.h>
using namespace std;
// Per-node Heavy-Light Decomposition record.
struct HLD {
    int position = 0;  // index of the node in hldList (its slot in HLD order)
    int highest = 0;   // head (topmost node) of the heavy chain containing the node
    HLD() = default;
    HLD(int p, int h) : position(p), highest(h) {}
};
constexpr int N = 100000 + 7;  // maximum node count plus slack
int numNode;                   // number of nodes read from input
vector<int> adj[N];            // adjacency lists, nodes are 1-indexed
vector<int> hldList;           // nodes in HLD order: hldList[hld[u].position] == u
int sz[N];                     // subtree sizes filled by dfs
HLD hld[N];                    // HLD position / chain head of every node
// Computes sz[u] = number of nodes in the subtree of u (rooted away from p).
void dfs(int u, int p) {
    int total = 1;  // u itself
    for (int v : adj[u]) {
        if (v == p) continue;
        dfs(v, u);
        total += sz[v];
    }
    sz[u] = total;
}
// Lays the tree out in HLD order. u is appended to hldList with its chain head.
// Then the heavy child (the first child with the largest subtree) continues the
// same chain, and every light child starts a new chain headed by itself.
void dfsHLD(int u, int p, int highest) {
    hld[u] = HLD(hldList.size(), highest);
    hldList.push_back(u);

    int heavy = -1;
    int heavySize = 0;
    for (int v : adj[u]) {
        if (v != p && sz[v] > heavySize) {
            heavy = v;
            heavySize = sz[v];
        }
    }
    if (heavy == -1) return;  // u is a leaf

    dfsHLD(heavy, u, highest);
    for (int v : adj[u]) {
        if (v == p || v == heavy) continue;
        dfsHLD(v, u, v);
    }
}
// Builds the decomposition rooted at node 1. Subtree sizes must be known
// before the heavy children can be chosen.
void buildHLD() {
    const int root = 1;
    dfs(root, 0);
    dfsHLD(root, 0, root);
}
// Reads a tree (numNode, then numNode - 1 undirected edges) and builds its HLD.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> numNode;
    for (int remaining = numNode - 1; remaining > 0; --remaining) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    buildHLD();
    return 0;
}
Độ phức tạp xây dựng HLD: \(O(N)\) (hai lần DFS, mỗi đỉnh được duyệt đúng một lần trong mỗi lần DFS).
Nhận xét: khi xử lý các truy vấn trên đường đi từ đỉnh u đến đỉnh v, ta thường cần thêm LCA và phải kết hợp với các cấu trúc dữ liệu xử lý truy vấn trên đoạn như BIT, Segment Tree, Sparse Table, …, nên code thường khá dài.
Độ phức tạp:
- Tiền xử lí: \(O(N \log N)\) (thừa số \(\log N\) thường đến từ việc chuẩn bị LCA và khởi tạo cấu trúc dữ liệu)
- Xử lí truy vấn: \(O(C \log N)\) (\(O(\log N)\) là số lần nhảy giữa các chuỗi nặng, \(O(C)\) là độ phức tạp khi xử lí một đoạn, thường là \(O(\log N)\))
Để hiểu thêm về cách xử lí truy vấn, có thể tham khảo code một số bài tập dưới đây:
solution
#include <bits/stdc++.h>
using namespace std;
// Per-node Heavy-Light Decomposition record.
struct HLD {
    int position = 0;  // index of the node in hldList / the segment tree
    int highest = 0;   // head (topmost node) of the heavy chain containing the node
    HLD() = default;
    HLD(int p, int h) : position(p), highest(h) {}
};
// Weighted undirected edge (x, y) with weight w. buildHLD later swaps the
// endpoints so that y is the child.
struct Edge {
    int x = 0;
    int y = 0;
    int w = 0;
    Edge() = default;
    Edge(int _x, int _y, int _w) : x(_x), y(_y), w(_w) {}
};
// Point-assign / range-max segment tree over positions [0, size).
// All leaves start at 0. getMax returns INT_MIN for an empty range.
struct segtree {
    int size;          // number of leaves: smallest power of two >= numNode
    vector<int> node;  // heap layout: node[1] is the root; children of id are 2*id, 2*id+1

    // Allocates the tree and resets every value to 0.
    void init(int numNode) {
        for (size = 1; size < numNode; size *= 2) {
        }
        node.assign(2 * size, 0);
    }

    // Sets leaf i to k inside subtree `id`, which covers [l, r].
    void update(int id, int l, int r, int i, int k) {
        if (i < l || r < i) return;
        if (l == r) {
            node[id] = k;
            return;
        }
        const int mid = (l + r) / 2;
        if (i <= mid) {
            update(2 * id, l, mid, i, k);
        } else {
            update(2 * id + 1, mid + 1, r, i, k);
        }
        node[id] = max(node[2 * id], node[2 * id + 1]);
    }
    void update(int i, int k) { update(1, 0, size - 1, i, k); }

    // Maximum over the intersection of [x, y] and [l, r]; INT_MIN if the intersection is empty.
    int getMax(int id, int l, int r, int x, int y) {
        if (y < l || r < x) return INT_MIN;
        if (x <= l && r <= y) return node[id];
        const int mid = (l + r) / 2;
        const int bestLeft = getMax(2 * id, l, mid, x, y);
        const int bestRight = getMax(2 * id + 1, mid + 1, r, x, y);
        return max(bestLeft, bestRight);
    }
    int getMax(int x, int y) { return x > y ? INT_MIN : getMax(1, 0, size - 1, x, y); }
} st;
constexpr int N = 10000 + 7;  // maximum node count plus slack
constexpr int LOG = 15;       // binary-lifting levels; 2^(LOG-1) > N
int numNode;
int sz[N];                    // subtree sizes
int anc[N][LOG];              // anc[u][j] = 2^j-th ancestor of u (0 above the root)
int depth[N];                 // node depths; depth[0] is set as a sentinel in buildHLD
vector<int> adj[N];           // adj[u] = indices into `edge` of the edges touching u
vector<int> hldList;          // nodes in HLD order
vector<Edge> edge;            // input edges
HLD hld[N];                   // HLD position / chain head of every node
// Fills the binary-lifting row of u (anc[u][0] is set by the caller), then the
// depth / parent of each child, and finally sz[u].
void dfs(int u, int p) {
    for (int j = 1; j < LOG; ++j) anc[u][j] = anc[anc[u][j - 1]][j - 1];
    sz[u] = 1;
    for (int e : adj[u]) {
        const int v = edge[e].x ^ edge[e].y ^ u;  // the other endpoint of edge e
        if (v == p) continue;
        anc[v][0] = u;
        depth[v] = depth[u] + 1;
        dfs(v, u);
        sz[u] += sz[v];
    }
}
// Lays the tree out in HLD order. The heavy child (the first child with the
// largest subtree) continues u's chain; each light child starts its own chain.
void dfsHLD(int u, int p, int highest) {
    hld[u] = HLD(hldList.size(), highest);
    hldList.push_back(u);

    int heavy = -1;
    int heavySize = 0;
    for (int e : adj[u]) {
        const int v = edge[e].x ^ edge[e].y ^ u;
        if (v != p && sz[v] > heavySize) {
            heavy = v;
            heavySize = sz[v];
        }
    }
    if (heavy == -1) return;  // leaf

    dfsHLD(heavy, u, highest);
    for (int e : adj[u]) {
        const int v = edge[e].x ^ edge[e].y ^ u;
        if (v == p || v == heavy) continue;
        dfsHLD(v, u, v);
    }
}
void buildHLD() {
depth[0] = -1;
dfs(1,0);
dfsHLD(1, 0, 1);
st.init(numNode);
for (int i = 0; i < edge.size(); ++i) {
if (sz[edge[i].x] < sz[edge[i].y]) swap(edge[i].x,edge[i].y);
st.update(hld[edge[i].y].position, edge[i].w);
}
}
// Lowest common ancestor of u and v via binary lifting.
int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    // Lift u by exactly depth[u] - depth[v] levels, one set bit at a time.
    int gap = depth[u] - depth[v];
    for (int j = 0; gap > 0 && j < LOG; ++j, gap >>= 1) {
        if (gap & 1) u = anc[u][j];
    }
    if (u == v) return u;
    // Lift both while their ancestors differ; their parent is then the LCA.
    for (int j = LOG - 1; j >= 0; --j) {
        if (anc[u][j] == anc[v][j]) continue;
        u = anc[u][j];
        v = anc[v][j];
    }
    return anc[u][0];
}
// Raises x to y when y is larger.
void maximize(int &x, int y) {
    x = max(x, y);
}
void update(int id, int val) {
st.update(hld[edge[id].y].position, val);
}
// Maximum edge weight on the path u -> v.
// Each edge's weight is stored at its child's position, so the LCA's own slot is
// excluded (hence position + 1 on the LCA's chain).
// Returns 0 when u == v: that path has no edges. Without this guard the
// INT_MIN "empty range" sentinel would leak to the output. The NEGATE variant
// of this program uses the same convention.
int getMax(int u, int v) {
    if (u == v) return 0;
    const int k = lca(u, v);
    int best = INT_MIN;
    for (int w : {u, v}) {
        // Climb whole chains until w shares the LCA's chain.
        while (hld[w].highest != hld[k].highest) {
            const int head = hld[w].highest;
            maximize(best, st.getMax(hld[head].position, hld[w].position));
            w = anc[head][0];
        }
        maximize(best, st.getMax(hld[k].position + 1, hld[w].position));
    }
    return best;
}
// Multi-test driver. Each test has a weighted tree, then commands until "DONE":
//   CHANGE i w  -> set the weight of the i-th input edge (1-based)
//   QUERY a b   -> print the max edge weight on the path a -> b
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int nTest;
    cin >> nTest;
    while (nTest-- != 0) {
        cin >> numNode;
        for (int remaining = numNode - 1; remaining > 0; --remaining) {
            int x, y, w;
            cin >> x >> y >> w;
            const int id = edge.size();
            adj[x].push_back(id);
            adj[y].push_back(id);
            edge.emplace_back(x, y, w);
        }
        buildHLD();

        for (string t; cin >> t && t != "DONE";) {
            int a, b;
            cin >> a >> b;
            if (t == "CHANGE") {
                update(a - 1, b);
            } else {
                cout << getMax(a, b) << '\n';
            }
        }

        // Reset the per-test state; release the adjacency storage.
        hldList.clear();
        edge.clear();
        for (int i = 1; i <= numNode; ++i) vector<int>().swap(adj[i]);
    }
    return 0;
}
solution
#include <bits/stdc++.h>
using namespace std;
// Per-node Heavy-Light Decomposition record.
struct HLD {
    int position = 0;  // index of the node in hldList / the segment tree
    int highest = 0;   // head (topmost node) of the heavy chain containing the node
    HLD() = default;
    HLD(int p, int h) : position(p), highest(h) {}
};
// Segment tree of toggleable marks. A marked leaf i stores i, an unmarked leaf
// stores INT_MAX. An internal node stores the minimum, i.e. the smallest marked
// position in its range.
struct segtree {
    int size;          // number of leaves: smallest power of two >= numNode
    vector<int> node;  // heap layout, node[1] is the root

    // Allocates the tree with every position unmarked.
    void init(int numNode) {
        for (size = 1; size < numNode; size *= 2) {
        }
        node.assign(2 * size, INT_MAX);
    }

    // Flips the mark of position i inside subtree `id` covering [l, r].
    void update(int id, int l, int r, int i) {
        if (i < l || r < i) return;
        if (l == r) {
            if (node[id] == INT_MAX) {
                node[id] = i;
            } else {
                node[id] = INT_MAX;
            }
            return;
        }
        const int mid = (l + r) / 2;
        if (i <= mid) {
            update(2 * id, l, mid, i);
        } else {
            update(2 * id + 1, mid + 1, r, i);
        }
        node[id] = min(node[2 * id], node[2 * id + 1]);
    }
    void update(int i) { update(1, 0, size - 1, i); }

    // Smallest marked position in [x, y] within subtree `id`; INT_MAX if none.
    int getMin(int id, int l, int r, int x, int y) {
        if (y < l || r < x) return INT_MAX;
        if (x <= l && r <= y) return node[id];
        const int mid = (l + r) / 2;
        const int bestLeft = getMin(2 * id, l, mid, x, y);
        const int bestRight = getMin(2 * id + 1, mid + 1, r, x, y);
        return min(bestLeft, bestRight);
    }
    int getMin(int x, int y) { return x > y ? INT_MAX : getMin(1, 0, size - 1, x, y); }
} st;
constexpr int N = 100000 + 7;  // maximum node count plus slack
int numNode;                   // number of nodes
int numQue;                    // number of queries
int sz[N];                     // subtree sizes
int par[N];                    // parent of each node (root's parent stays 0)
vector<int> adj[N];            // adjacency lists, nodes are 1-indexed
vector<int> hldList;           // nodes in HLD order: hldList[hld[u].position] == u
HLD hld[N];                    // HLD position / chain head of every node
// Records the parent of every child of u and computes sz[u].
void dfs(int u, int p) {
    int total = 1;
    for (int v : adj[u]) {
        if (v == p) continue;
        par[v] = u;
        dfs(v, u);
        total += sz[v];
    }
    sz[u] = total;
}
// Lays the tree out in HLD order. The heavy child (the first child with the
// largest subtree) continues u's chain; each light child starts its own chain.
void dfsHLD(int u, int p, int highest) {
    hld[u] = HLD(hldList.size(), highest);
    hldList.push_back(u);

    int heavy = -1;
    int heavySize = 0;
    for (int v : adj[u]) {
        if (v != p && sz[v] > heavySize) {
            heavy = v;
            heavySize = sz[v];
        }
    }
    if (heavy == -1) return;  // leaf

    dfsHLD(heavy, u, highest);
    for (int v : adj[u]) {
        if (v == p || v == heavy) continue;
        dfsHLD(v, u, v);
    }
}
// Roots the tree at node 1, builds the HLD, and creates an all-unmarked segment tree.
void buildHLD() {
    const int root = 1;
    dfs(root, 0);
    dfsHLD(root, 0, root);
    st.init(numNode);
}
// Lowers x to y when y is smaller.
void minimize(int &x, int y) {
    x = min(x, y);
}
// Toggles the mark on node u (stored at its HLD position).
void update(int u) {
    const int pos = hld[u].position;
    st.update(pos);
}
// Returns the marked node closest to the root on the path 1 -> u, or -1 if
// no node on that path is marked. dfsHLD assigns positions in preorder, so an
// ancestor always has a smaller position than its descendants. The smallest
// marked position on the path therefore belongs to the node nearest the root.
int FindFirst(int u) {
    int best = INT_MAX;
    while (hld[u].highest != hld[1].highest) {
        const int head = hld[u].highest;
        minimize(best, st.getMin(hld[head].position, hld[u].position));
        u = par[head];
    }
    minimize(best, st.getMin(hld[1].position, hld[u].position));
    if (best == INT_MAX) return -1;
    return hldList[best];
}
// Reads a tree and numQue queries:
//   0 u -> toggle the mark on node u
//   1 u -> print the first marked node on the path 1 -> u (-1 if none)
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> numNode >> numQue;
    for (int remaining = numNode - 1; remaining > 0; --remaining) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    buildHLD();
    while (numQue-- != 0) {
        int type, x;
        cin >> type >> x;
        if (type != 0) {
            cout << FindFirst(x) << '\n';
        } else {
            update(x);
        }
    }
    return 0;
}
solution
#include <bits/stdc++.h>
using namespace std;
// Per-node Heavy-Light Decomposition record.
struct HLD {
    int position = 0;  // index of the node in hldList / the segment tree
    int highest = 0;   // head (topmost node) of the heavy chain containing the node
    HLD() = default;
    HLD(int p, int h) : position(p), highest(h) {}
};
// Weighted undirected edge (x, y) with weight w. buildHLD later swaps the
// endpoints so that y is the child.
struct Edge {
    int x = 0;
    int y = 0;
    int w = 0;
    Edge() = default;
    Edge(int _x, int _y, int _w) : x(_x), y(_y), w(_w) {}
};
// Segment tree over positions [0, size) supporting point assignment, lazy
// range negation and range maximum. Every node keeps both its min and its max,
// because negating a range swaps them: max(-a) = -min(a).
struct segtree {
    int size;                      // number of leaves: smallest power of two >= numNode
    vector<int> maxs, mins, lazy;  // lazy[id] == -1: the children still owe a negation

    // Allocates the tree with all values 0 and no pending negations.
    void init(int numNode) {
        for (size = 1; size < numNode; size *= 2) {
        }
        maxs.assign(2 * size, 0);
        mins.assign(2 * size, 0);
        lazy.assign(2 * size, 1);
    }

    // Negates every value under node `id` and records it in the node's lazy tag.
    void flip(int id) {
        lazy[id] = -lazy[id];
        const int oldMax = maxs[id];
        maxs[id] = -mins[id];
        mins[id] = -oldMax;
    }

    // Hands a pending negation of `id` down to its two children.
    void pushDown(int id) {
        if (lazy[id] == 1) return;
        flip(2 * id);
        flip(2 * id + 1);
        lazy[id] = 1;
    }

    // Recomputes node `id` from its children.
    void pull(int id) {
        mins[id] = min(mins[2 * id], mins[2 * id + 1]);
        maxs[id] = max(maxs[2 * id], maxs[2 * id + 1]);
    }

    // Sets position i to k inside subtree `id` covering [l, r].
    void update(int id, int l, int r, int i, int k) {
        if (i < l || r < i) return;
        if (l == r) {
            mins[id] = maxs[id] = k;
            return;
        }
        pushDown(id);
        const int mid = (l + r) / 2;
        if (i <= mid) {
            update(2 * id, l, mid, i, k);
        } else {
            update(2 * id + 1, mid + 1, r, i, k);
        }
        pull(id);
    }
    void update(int i, int k) { update(1, 0, size - 1, i, k); }

    // Negates every value in [x, y] within subtree `id` covering [l, r].
    void Negate(int id, int l, int r, int x, int y) {
        if (y < l || r < x) return;
        if (x <= l && r <= y) {
            flip(id);
            return;
        }
        pushDown(id);
        const int mid = (l + r) / 2;
        Negate(2 * id, l, mid, x, y);
        Negate(2 * id + 1, mid + 1, r, x, y);
        pull(id);
    }
    void Negate(int x, int y) {
        if (x <= y) Negate(1, 0, size - 1, x, y);
    }

    // Maximum over [x, y] within subtree `id`; INT_MIN if the ranges are disjoint.
    int getMax(int id, int l, int r, int x, int y) {
        if (y < l || r < x) return INT_MIN;
        if (x <= l && r <= y) return maxs[id];
        pushDown(id);
        const int mid = (l + r) / 2;
        const int bestLeft = getMax(2 * id, l, mid, x, y);
        const int bestRight = getMax(2 * id + 1, mid + 1, r, x, y);
        return max(bestLeft, bestRight);
    }
    int getMax(int x, int y) { return x > y ? INT_MIN : getMax(1, 0, size - 1, x, y); }
} st;
constexpr int N = 10000 + 7;  // maximum node count plus slack
constexpr int LOG = 15;       // binary-lifting levels; 2^(LOG-1) > N
int numNode;
int sz[N];                    // subtree sizes
int anc[N][LOG];              // anc[u][j] = 2^j-th ancestor of u (0 above the root)
int depth[N];                 // node depths; depth[0] is set as a sentinel in buildHLD
vector<int> adj[N];           // adj[u] = indices into `edge` of the edges touching u
vector<int> hldList;          // nodes in HLD order
vector<Edge> edge;            // input edges
HLD hld[N];                   // HLD position / chain head of every node
// Fills the binary-lifting row of u (anc[u][0] is set by the caller), then the
// depth / parent of each child, and finally sz[u].
void dfs(int u, int p) {
    for (int j = 1; j < LOG; ++j) anc[u][j] = anc[anc[u][j - 1]][j - 1];
    sz[u] = 1;
    for (int e : adj[u]) {
        const int v = edge[e].x ^ edge[e].y ^ u;  // the other endpoint of edge e
        if (v == p) continue;
        anc[v][0] = u;
        depth[v] = depth[u] + 1;
        dfs(v, u);
        sz[u] += sz[v];
    }
}
// Lays the tree out in HLD order. The heavy child (the first child with the
// largest subtree) continues u's chain; each light child starts its own chain.
void dfsHLD(int u, int p, int highest) {
    hld[u] = HLD(hldList.size(), highest);
    hldList.push_back(u);

    int heavy = -1;
    int heavySize = 0;
    for (int e : adj[u]) {
        const int v = edge[e].x ^ edge[e].y ^ u;
        if (v != p && sz[v] > heavySize) {
            heavy = v;
            heavySize = sz[v];
        }
    }
    if (heavy == -1) return;  // leaf

    dfsHLD(heavy, u, highest);
    for (int e : adj[u]) {
        const int v = edge[e].x ^ edge[e].y ^ u;
        if (v == p || v == heavy) continue;
        dfsHLD(v, u, v);
    }
}
void buildHLD() {
depth[0] = -1;
dfs(1,0);
dfsHLD(1, 0, 1);
st.init(numNode);
for (int i = 0; i < edge.size(); ++i) {
if (sz[edge[i].x] < sz[edge[i].y]) swap(edge[i].x,edge[i].y);
st.update(hld[edge[i].y].position, edge[i].w);
}
}
// Lowest common ancestor of u and v via binary lifting.
int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    // Lift u by exactly depth[u] - depth[v] levels, one set bit at a time.
    int gap = depth[u] - depth[v];
    for (int j = 0; gap > 0 && j < LOG; ++j, gap >>= 1) {
        if (gap & 1) u = anc[u][j];
    }
    if (u == v) return u;
    // Lift both while their ancestors differ; their parent is then the LCA.
    for (int j = LOG - 1; j >= 0; --j) {
        if (anc[u][j] == anc[v][j]) continue;
        u = anc[u][j];
        v = anc[v][j];
    }
    return anc[u][0];
}
// Raises x to y when y is larger.
void maximize(int &x, int y) {
    x = max(x, y);
}
void update(int id, int val) {
st.update(hld[edge[id].y].position, val);
}
// Negates the weight of every edge on the path u -> v. Edge weights sit at the
// child's position, so the LCA's own slot is skipped (position + 1).
void Negate(int u, int v) {
    const int k = lca(u, v);
    for (int w : {u, v}) {
        // Climb whole chains until w shares the LCA's chain.
        while (hld[w].highest != hld[k].highest) {
            const int head = hld[w].highest;
            st.Negate(hld[head].position, hld[w].position);
            w = anc[head][0];
        }
        st.Negate(hld[k].position + 1, hld[w].position);
    }
}
// Maximum edge weight on the path u -> v; 0 for the empty path u == v.
// Edge weights sit at the child's position, so the LCA's own slot is skipped.
int getMax(int u, int v) {
    if (u == v) return 0;
    const int k = lca(u, v);
    int best = INT_MIN;
    for (int w : {u, v}) {
        while (hld[w].highest != hld[k].highest) {
            const int head = hld[w].highest;
            maximize(best, st.getMax(hld[head].position, hld[w].position));
            w = anc[head][0];
        }
        maximize(best, st.getMax(hld[k].position + 1, hld[w].position));
    }
    return best;
}
// Multi-test driver. Each test has a weighted tree, then commands until "DONE":
//   CHANGE i w  -> set the weight of the i-th input edge (1-based)
//   NEGATE a b  -> negate every edge weight on the path a -> b
//   QUERY a b   -> print the max edge weight on the path a -> b
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int nTest;
    cin >> nTest;
    while (nTest-- != 0) {
        cin >> numNode;
        for (int remaining = numNode - 1; remaining > 0; --remaining) {
            int x, y, w;
            cin >> x >> y >> w;
            const int id = edge.size();
            adj[x].push_back(id);
            adj[y].push_back(id);
            edge.emplace_back(x, y, w);
        }
        buildHLD();

        for (string t; cin >> t && t != "DONE";) {
            int a, b;
            cin >> a >> b;
            if (t == "CHANGE") {
                update(a - 1, b);
            } else if (t == "NEGATE") {
                Negate(a, b);
            } else {
                cout << getMax(a, b) << '\n';
            }
        }

        // Reset the per-test state; release the adjacency storage.
        hldList.clear();
        edge.clear();
        for (int i = 1; i <= numNode; ++i) vector<int>().swap(adj[i]);
    }
    return 0;
}
solution
#include <bits/stdc++.h>
using namespace std;
// Per-node Heavy-Light Decomposition record.
// All three fields now have default member initializers. Previously neither
// constructor set `last`, so `hld[u] = HLD(...)` copied an indeterminate int.
struct HLD {
    int position = 0;  // index of the node in hldList / the segment tree
    int highest = 0;   // head (topmost node) of the heavy chain containing the node
    int last = 0;      // one past the last position of the node's subtree (set in dfsHLD)
    HLD() = default;
    HLD(int p, int h) : position(p), highest(h) {}
};
constexpr int N = 100000 + 7;  // maximum node count plus slack
constexpr int LOG = 18;        // binary-lifting levels; 2^(LOG-1) > N
int numNode;                   // number of nodes
int numQue;                    // number of queries
int anc[N][LOG];               // anc[u][j] = 2^j-th ancestor of u (0 above the root)
int depth[N];                  // node depths; depth[0] is set as a sentinel in buildHLD
int sz[N];                     // subtree sizes
vector<int> adj[N];            // adjacency lists, nodes are 1-indexed
vector<int> hldList;           // nodes in HLD order: hldList[hld[u].position] == u
HLD hld[N];                    // HLD position / chain head / subtree end of every node
// Lazy range-add segment tree over HLD positions. If a[i] is the total added at
// position i, then for the range of node id:
//   node[id].first  = sum of sz[hldList[i]] * a[i]
//   node[id].second = sum of a[i]
// val[id] caches the sum of sz[hldList[i]] over the range, so adding k to the
// whole range raises node[id].first by val[id] * k without visiting the leaves.
struct segtree {
    int size;                                 // number of leaves (power of two >= n)
    vector<long long> lazy, val;              // pending additions / cached size weights
    vector<pair<long long, long long>> node;  // (weighted sum, plain sum)

    // Allocates the tree with all sums 0 and fills the size weights.
    void init(int n) {
        for (size = 1; size < n; size *= 2) {
        }
        node.assign(2 * size, make_pair(0, 0));
        val.assign(2 * size, 0);
        lazy.assign(2 * size, 0);
        build(1, 0, size - 1);
    }

    // Fills val[] bottom-up. Padding leaves (l >= numNode) keep weight 0.
    void build(int id, int l, int r) {
        if (l == r) {
            if (l < numNode) val[id] = sz[hldList[l]];
            return;
        }
        const int mid = (l + r) / 2;
        build(2 * id, l, mid);
        build(2 * id + 1, mid + 1, r);
        val[id] = val[2 * id] + val[2 * id + 1];
    }

    // Adds k to every position under node `id`, whose range holds `len` positions.
    void applyAdd(int id, long long len, long long k) {
        node[id].first += val[id] * k;
        node[id].second += k * len;
        lazy[id] += k;
    }

    // Hands the pending addition of `id` (range length len) down to its children.
    void pushDown(int id, int len) {
        if (lazy[id] == 0) return;
        applyAdd(2 * id, len / 2, lazy[id]);
        applyAdd(2 * id + 1, len / 2, lazy[id]);
        lazy[id] = 0;
    }

    void add(int id, int l, int r, int x, int y, int k) {
        if (y < l || r < x) return;
        if (x <= l && r <= y) {
            applyAdd(id, r - l + 1, k);
            return;
        }
        pushDown(id, r - l + 1);
        const int mid = (l + r) / 2;
        add(2 * id, l, mid, x, y, k);
        add(2 * id + 1, mid + 1, r, x, y, k);
        node[id].first = node[2 * id].first + node[2 * id + 1].first;
        node[id].second = node[2 * id].second + node[2 * id + 1].second;
    }
    // Adds k to every position in [x, y]; an empty range is a no-op.
    void add(int x, int y, int k) {
        if (x <= y) add(1, 0, size - 1, x, y, k);
    }

    // Weighted sum (node[].first) over [x, y].
    long long sum(int id, int l, int r, int x, int y) {
        if (y < l || r < x) return 0;
        if (x <= l && r <= y) return node[id].first;
        pushDown(id, r - l + 1);
        const int mid = (l + r) / 2;
        const long long left = sum(2 * id, l, mid, x, y);
        return left + sum(2 * id + 1, mid + 1, r, x, y);
    }
    long long sum(int x, int y) { return sum(1, 0, size - 1, x, y); }

    // Plain sum of additions (node[].second) over [x, y].
    long long sum2(int id, int l, int r, int x, int y) {
        if (y < l || r < x) return 0;
        if (x <= l && r <= y) return node[id].second;
        pushDown(id, r - l + 1);
        const int mid = (l + r) / 2;
        const long long left = sum2(2 * id, l, mid, x, y);
        return left + sum2(2 * id + 1, mid + 1, r, x, y);
    }
    long long sum2(int x, int y) { return sum2(1, 0, size - 1, x, y); }
} st;
// Computes sz[u] and the binary-lifting row of u (anc[u][0] is set by the
// caller), and sets the depth and parent of each child.
void dfs(int u, int p) {
    sz[u] = 1;
    for (int j = 1; j < LOG; ++j) anc[u][j] = anc[anc[u][j - 1]][j - 1];
    for (int v : adj[u]) {
        if (v == p) continue;
        anc[v][0] = u;
        depth[v] = depth[u] + 1;
        dfs(v, u);
        sz[u] += sz[v];
    }
}
// Lays the tree out in HLD order. The heavy child (the first child with the
// largest subtree) continues u's chain; each light child starts its own chain.
// Because this is a preorder walk, u's subtree occupies positions
// [hld[u].position, hld[u].last).
void dfsHLD(int u, int p, int highest) {
    hld[u] = HLD(hldList.size(), highest);
    hldList.push_back(u);

    int heavy = -1;
    int heavySize = 0;
    for (int v : adj[u]) {
        if (v != p && sz[v] > heavySize) {
            heavy = v;
            heavySize = sz[v];
        }
    }
    if (heavy != -1) {
        dfsHLD(heavy, u, highest);
        for (int v : adj[u]) {
            if (v == p || v == heavy) continue;
            dfsHLD(v, u, v);
        }
    }
    hld[u].last = hldList.size();
}
// Roots the tree at node 1 and builds the HLD. The segment tree is initialised
// last, because its build step reads hldList and sz.
// After this call, the subtree of u is [hld[u].position, hld[u].last).
void buildHLD() {
    depth[0] = -1;  // sentinel: the virtual parent of the root is shallower than any node
    const int root = 1;
    dfs(root, -1);
    dfsHLD(root, -1, root);
    st.init(numNode);
}
// Lowest common ancestor of u and v via binary lifting.
int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    // Lift u by exactly depth[u] - depth[v] levels, one set bit at a time.
    int gap = depth[u] - depth[v];
    for (int j = 0; gap > 0 && j < LOG; ++j, gap >>= 1) {
        if (gap & 1) u = anc[u][j];
    }
    if (u == v) return u;
    // Lift both while their ancestors differ; their parent is then the LCA.
    for (int j = LOG - 1; j >= 0; --j) {
        if (anc[u][j] == anc[v][j]) continue;
        u = anc[u][j];
        v = anc[v][j];
    }
    return anc[u][0];
}
// Adds k to every node on the path u -> v, counting the LCA exactly once.
void add(int u, int v, int k) {
    const int meet = lca(u, v);
    // u side: chains up to and including the LCA.
    while (hld[u].highest != hld[meet].highest) {
        const int head = hld[u].highest;
        st.add(hld[head].position, hld[u].position, k);
        u = anc[head][0];
    }
    st.add(hld[meet].position, hld[u].position, k);
    // v side: chains up to, but excluding, the LCA.
    while (hld[v].highest != hld[meet].highest) {
        const int head = hld[v].highest;
        st.add(hld[head].position, hld[v].position, k);
        v = anc[head][0];
    }
    st.add(hld[meet].position + 1, hld[v].position, k);
}
long long sum(int u, int num) {
long long s = 0;
for (; hld[u].highest != hld[1].highest; u = anc[hld[u].highest][0]) {
s += st.sum2(hld[hld[u].highest].position, hld[u].position);
}
s += st.sum2(hld[1].position, hld[u].position);
return s * num;
}
long long getVal(int u) {
long long val = st.sum(hld[u].position, hld[u].last-1);
if (u != 1) val += sum(anc[u][0], sz[u]);
return val;
}
// Reads a tree and numQue queries:
//   1 u v k -> add k to every node on the path u -> v
//   other u -> print getVal(u) % 10009
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> numNode;
    for (int remaining = numNode - 1; remaining > 0; --remaining) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    buildHLD();
    cin >> numQue;
    while (numQue-- != 0) {
        int type;
        cin >> type;
        if (type != 1) {
            int u;
            cin >> u;
            cout << getVal(u) % 10009 << '\n';
            continue;
        }
        int u, v, k;
        cin >> u >> v >> k;
        add(u, v, k);
    }
    return 0;
}