【集训整理】最近公共祖先LCA 模板题


最近公共祖先:在有根树中,两个节点的最近的公共祖先

也就意味着,树上两个节点的最短距离就是他们的最近公共祖先到这两个节点距离之和

朴素算法

我们可以通过一遍dfs记录下每个节点的深度信息。

在查询的时候,先让深度大的点往上跳,直到两点深度相等。两点深度相等以后就一起一步一步往上跳,直到跳到同一个点。

显然这样的做法过于暴力,会T。

倍增算法

倍增:由于所有数都可以由 a12n+a22n1+...+an121+an20a_12^n + a_22^{n-1}+...+a_{n-1}2^1+a_n2^0表示,所以相比于从1跳到n,我们可以一次性跳2k(k<=log2n)2^k(k<=log_2n)步,然后逐步缩小k直到跳到n。这样就可以将复杂度降至o(logn)。

例:跳到14,14 = 8 + 4 + 2

朴素算法为什么会慢呢?

朴素算法在往上跳的时候是一步一步跳的,这样跳万一两点都很深,那么就会很慢。我们可以利用倍增的思想,一次跳一大步,然后再逐渐缩小步长逼近目标。

为了实现快速的一次跳2k2^k步的目标,我们采用空间换时间——在dfs结束的时候算出所有节点的2k2^k级祖先(根据数据范围决定k的最大值,这里选20)。

fa[i][j]fa[i][j]表示第i号节点的2j2^j级祖先,得到一开始的dfs如下:

void dfs(int now, int dep, int father) {
	// 预处理深度
	if (vis[now]) return;
	vis[now] = true;
	depth[now] = dep;
	fa[now][0] = father;
	for (int v : vec[now]) {
		dfs(v, dep + 1, now);
	}
}

这样就用dfs先记录下自己的202^0级祖先(直系父亲)。

接下来通过两个循环,做一个类似于dp一样的转移来求fa:

怎么求fa:

通过一个简单的原理: 2i=2i1+2i12^i = 2^{i-1} + 2^{i-1}

可得到:fa[now][i]=fa[fa[now][i1]][i1];fa[now][i] = fa[fa[now][i-1]][i-1];

意思是now的2i2^i祖先等于now的2i12^{i-1}祖先的2i12^{i-1}祖先

for (int j = 1; j <= 20; j++)
	for (int i = 1; i <= n; i++)
		fa[i][j] = fa[fa[i][j - 1]][j - 1];

这里由于直接固定了遍历到2的20次方,有可能跳的太多,会超过树的根,也就是0,所以在lca的时候要把0的情况剔除。

接下来就是倍增跳LCA的过程了。

跳LCA有两个步骤:跳到同一深度,一起往上跳。

跳到同一深度后,如果两个节点相同,直接返回;否则两个点一起往上跳。

两个点一起跳,能跳的条件是:跳这个步长还在树上(fa[i][j]!=0fa[i][j]!= 0

一直跳到找到相同节点位置。跳的时候从20开始逐步缩小步长。

int lca(int x, int y) { //用倍增法求lca
	if (depth[x] < depth[y]) swap(x, y);
	for (int i = 20; i >= 0; i--) {
		if (fa[x][i] != 0 && depth[fa[x][i]] >= depth[y])
			x = fa[x][i];
	}
	if (x == y) return x;
	for (int i = 20; i >= 0; i--) {
		if (fa[x][i] != 0 && fa[y][i] != 0 && fa[x][i] != fa[y][i]) {
			x = fa[x][i];
			y = fa[y][i];
		}
	}
	return fa[x][0];
}

完整代码:

#include<iostream>
#include<vector>
#include<cstdio>
using namespace std;

const int maxn = 6e5;

int n, m;
int depth[maxn], fa[maxn][21];
bool vis[maxn];
vector<int> vec[maxn];

void dfs(int now, int dep, int father) {
	// 预处理深度
	if (vis[now]) return;
	vis[now] = true;
	depth[now] = dep;
	fa[now][0] = father;
	for (int v : vec[now]) {
		dfs(v, dep + 1, now);
	}
}

int lca(int x, int y) { //用倍增法求lca
	if (depth[x] < depth[y]) swap(x, y);
	for (int i = 20; i >= 0; i--) {
		if (fa[x][i] != 0 && depth[fa[x][i]] >= depth[y])
			x = fa[x][i];
	}
	if (x == y) return x;
	for (int i = 20; i >= 0; i--) {
		if (fa[x][i] != 0 && fa[y][i] != 0 && fa[x][i] != fa[y][i]) {
			x = fa[x][i];
			y = fa[y][i];
		}
	}
	return fa[x][0];

}

int main() {
	cin >> n >> m;
	for (int i = 1; i <= n-1; i++) {
		int a, b;
		scanf("%d %d", &a, &b);
		vec[a].push_back(b);
		vec[b].push_back(a);
	}
	dfs(1, 1, 0);
	for (int j = 1; j <= 20; j++)
		for (int i = 1; i <= n; i++)
			fa[i][j] = fa[fa[i][j - 1]][j - 1];
	for (int i = 1; i <= m; i++) {
		int a, b;
		scanf("%d %d", &a, &b);
		printf("%d\n", lca(a, b);
	}
}

最后更新于