快乐树0x02 线段树实现(c++)


一、线段树的基本思想

线段树?

线段树是一种用来维护_区间信息_ 的数据结构。比如需要对一段连续的区间进行修改、查询等大量操作,使用线段树来维护相比线性表能取得更大优势。

线段树的时间复杂度为 O(logN)O(logN)级。

基本思路

线段树的基本思想是将每一个区间长度>1的部分划分成左右两个区间进行递归求解,在一层层划分时将整个线段划分成一个树形结构。这样就可以合并欲求区间包含的子树的信息来求得所求区间包含的全部信息。

这个树形结构是一棵二叉树,因此可以用二叉树的性质,求得:

某区间 pp 的左子区间(左儿子)的节点号为 p2p_2,右子区间(右儿子)节点号为 p2+1p_2+1

Lazy tag

由于线段树要支持对维护数据中的任意一个区间进行修改,而线段树是一个树形结构,如果每次修改都要对涉及的所有节点都去修改,其修改效率反而不如线性表。实际上,对欲修改区间,如果线段树某节点的区间正好全部在欲修改区间中,那么只需要修改该节点的值就可以了,不需要再深入到子节点中进行修改,因为至少目前用不到。对该节点进行修改后,我们给该节点打上一个标记,表明我们这次修改的数量,如果以后需要获取其子节点中的值,我们再把之前这次修改进行落实。这样,如果以后不需要其子节点的值,我们就省下了修改其子节点 的时间。

这个lazy tag的方法有点像初学的时候做的一道铺地毯的题,相比去修改每个被地毯覆盖的点,记录每次铺地毯的范围并与欲求坐标进行比对效率更高。

二、代码实现

建立线段树

递归分割建树。树的节点总数大致是原数集大小的4倍,开4*maxn。

int d[maxn * 4];//存树
int b[maxn * 4];//lazytag标记
int a[maxn];//原数据集

void build(int s, int t, int p) {
	//建立线段树,对[s,t]建树,当前根节点编号为p
	if (s == t) { //找到单个数了
		d[p] = a[s];
		return;
	}
	int m = s + ((t - s) / 2); //求中间值,二分递归
	build(s, m, p * 2);//p*2是左子节点
	build(m + 1, t, p * 2+1);// 右子节点
	d[p] = d[p * 2] + d[p * 2 + 1];//更新节点
}

getsum函数

写一个函数,取得区间和。依然是递归,递归边界是当前区间全部在所求区间内,就返回当前节点值。

对于遇见的lazy tag,将tag记录的修改作用到子节点上,然后把标记下放到子节点中。注意,所有被标记的节点是已经被作用了修改的,而其子节点还没有被作用修改。

int getsum(int l, int r, int s, int t, int p) {
	//取得区间和,l、r是最终目标范围,s、t是当前递归范围,p是当前范围的节点编号
	if (l <= s && t <= r) {
		//如果当前区间全部在目标范围内,直接返回
		return d[p];
	}
	int m = s + (t - s) / 2; //拆分,递归
	if (b[p]) { //如果有标记,往子节点访问要更新
		d[p * 2] += b[p] * (m - s + 1);
		d[p * 2 + 1] += b[p] * (t - m);
		b[p * 2] += b[p];
		b[p * 2 + 1] += b[p];
		b[p] = 0;
	}

	int sum = 0;
	if (l <= m) {
		//左半边有目标范围
		sum += getsum(l, r, s, m, p * 2);
	}
	if (r > m) {
		//右半边(注意右半边不包括m)有目标范围
		sum += getsum(l, r, m + 1, t, p * 2 + 1);
	}
	return sum;
}

add函数

用于区间修改(加或减)。使用lazy tag:在发现当前节点完全包含在目标节点时,就没有必要再修改子节点了,直接给本节点的值修改 修改值×本节点表示的长度修改值×本节点表示的长度 即可,同时在本节点打下标记。同样,在遇到标记的时候,要先更新子节点并下沉标记到子节点,然后再拆分递归。当然,没有必要处理叶子节点的标记,因为他们没有子节点。

void add(int l, int r, int c, int s, int t, int p) {
	//区间加(减),l,r是最终目标区间,c是增加的值,可以为负。s,t是当前区间,p是当前区间节点编号
	if (l <= s && t <= r) { //目标全包含当前区间,计算后返回
		d[p] += (t - s + 1) * c;
		b[p] += c; //标记好
		return;
	}
	//具体要往子节点访问,更新子节点并消除标记
	int m = s + ((t - s) / 2);
	if (b[p] && s != t) { //非叶子带标记,更新子节点值并下沉标记
		d[p * 2] += b[p] * (m - s + 1);//总和是修改了b[p]*num的
		d[p * 2 + 1] += b[p] * (t - m);
		b[p * 2] += b[p]; //标记下沉
		b[p * 2 + 1] += b[p];
		b[p] = 0;
		
	}
	if (l <= m) add(l, r, c, s, m, p * 2);
	if (r > m) add(l, r, c, m + 1, t, p * 2 + 1);
	d[p] = d[p * 2] + d[p * 2 + 1];
}

调用示例

注意,建树时节点编号必须>0,否则n*0都是0,整个树都会发生错误。

int main() {
	for (int i = 1; i <= 10; i++) { a[i] = 1;}
	build(1, 10, 1);
	add(1, 4, 0, 1, 10, 1);
	printf("%d", getsum(1, 10, 1, 10, 1));
 }

提供区间加减法、区间求和的完整代码模板:

#include<cstdio>
using namespace std;
#define maxn 10100
int d[maxn * 4];//存树
int b[maxn * 4];//lazytag标记
int a[maxn];//原数据集

void build(int s, int t, int p) {
	//建立线段树,对[s,t]建树,当前根节点编号为p
	if (s == t) { //找到单个数了
		d[p] = a[s];
		return;
	}
	int m = s + ((t - s) / 2); //求中间值,二分递归
	build(s, m, p * 2);//p*2是左子节点
	build(m + 1, t, p * 2+1);// 右子节点
	d[p] = d[p * 2] + d[p * 2 + 1];//更新节点
}

int getsum(int l, int r, int s, int t, int p) {
	//取得区间和,l、r是最终目标范围,s、t是当前递归范围,p是当前范围的节点编号
	if (l <= s && t <= r) {
		//如果当前区间全部在目标范围内,直接返回
		return d[p];
	}
	int m = s + (t - s) / 2; //拆分,递归
	if (b[p]) { //如果有标记,往子节点访问要更新
		d[p * 2] += b[p] * (m - s + 1);
		d[p * 2 + 1] += b[p] * (t - m);
		b[p * 2] += b[p];
		b[p * 2 + 1] += b[p];
		b[p] = 0;
	}

	int sum = 0;
	if (l <= m) {
		//左半边有目标范围
		sum += getsum(l, r, s, m, p * 2);
	}
	if (r > m) {
		//右半边(注意右半边不包括m)有目标范围
		sum += getsum(l, r, m + 1, t, p * 2 + 1);
	}
	return sum;
}

void add(int l, int r, int c, int s, int t, int p) {
	//区间加(减),l,r是最终目标区间,c是增加的值,可以为负。s,t是当前区间,p是当前区间节点编号
	if (l <= s && t <= r) { //目标全包含当前区间,计算后返回
		d[p] += (t - s + 1) * c;
		b[p] += c; //标记好
		return;
	}
	//具体要往子节点访问,更新子节点并消除标记
	int m = s + ((t - s) / 2);
	if (b[p] && s != t) { //非叶子带标记,更新子节点值并下沉标记
		d[p * 2] += b[p] * (m - s + 1);//总和是修改了b[p]*num的
		d[p * 2 + 1] += b[p] * (t - m);
		b[p * 2] += b[p]; //标记下沉
		b[p * 2 + 1] += b[p];
		b[p] = 0;
		
	}
	if (l <= m) add(l, r, c, s, m, p * 2);
	if (r > m) add(l, r, c, m + 1, t, p * 2 + 1);
	d[p] = d[p * 2] + d[p * 2 + 1];
}


int main() {
	for (int i = 1; i <= 10; i++) { a[i] = 1; b[i] = 0; }
	build(1, 10, 1);
	add(1, 4, 0, 1, 10, 1);
	printf("%d", getsum(1, 10, 1, 10, 1));
 }

提供区间加减、乘法、区间求和的代码模板:

#include<cstdio>
using namespace std;
#define maxn 10100
int d[maxn * 4];//存树
int b[maxn * 4];//lazytag标记
int b2[maxn * 4];//标记2
int a[maxn];//原数据集

void build(int s, int t, int p) {
	//建立线段树,对[s,t]建树,当前根节点编号为p
	if (s == t) { //找到单个数了
		d[p] = a[s];
		return;
	}
	int m = s + ((t - s) / 2); //求中间值,二分递归
	build(s, m, p * 2);//p*2是左子节点
	build(m + 1, t, p * 2+1);// 右子节点
	d[p] = d[p * 2] + d[p * 2 + 1];//更新节点
	d[p] %= 998244353;
}

int getsum(int l, int r, int s, int t, int p) {
	//取得区间和,l、r是最终目标范围,s、t是当前递归范围,p是当前范围的节点编号
	if (l <= s && t <= r) {
		//如果当前区间全部在目标范围内,直接返回
		return d[p];
	}
	int m = s + (t - s) / 2; //拆分,递归
	if (b[p]) { //如果有标记,往子节点访问要更新
		d[p * 2] += b[p] * (m - s + 1);
		d[p*2] %= 998244353;
		d[p * 2 + 1] += b[p] * (t - m);
		d[p * 2+1] %= 998244353;
		b[p * 2] += b[p];
		b[p * 2 + 1] += b[p];
		b[p] = 1;
	}

	if (b2[p]>1) { //如果有标记2,往子节点访问要更新
		d[p * 2] *= b2[p];
		d[p * 2] %= 998244353;
		d[p * 2 + 1] *= b2[p];
		d[p * 2 + 1] %= 998244353;
		b2[p * 2] *= b2[p];
		b2[p * 2 + 1] *= b2[p];
		b2[p] = 0;
	}

	int sum = 0;
	if (l <= m) {
		//左半边有目标范围
		sum += getsum(l, r, s, m, p * 2);
		sum %=998244353;
	}
	if (r > m) {
		//右半边(注意右半边不包括m)有目标范围
		sum += getsum(l, r, m + 1, t, p * 2 + 1);
		sum %= 998244353;
	}
	return sum;
}

void add(int l, int r, int c, int s, int t, int p) {
	//区间加(减),l,r是最终目标区间,c是增加的值,可以为负。s,t是当前区间,p是当前区间节点编号
	if (l <= s && t <= r) { //目标全包含当前区间,计算后返回
		d[p] += (t - s + 1) * c;
		d[p]%= 998244353;
		b[p] += c; //标记好
		return;
	}
	//具体要往子节点访问,更新子节点并消除标记
	int m = s + ((t - s) / 2);
	if (b[p] && s != t) { //非叶子带标记,更新子节点值并下沉标记
		d[p * 2] += b[p] * (m - s + 1);//总和是修改了b[p]*num的
		d[p * 2 + 1] += b[p] * (t - m);
		d[p * 2] %= 998244353;
		d[p * 2 + 1] %= 998244353;
		b[p * 2] += b[p]; //标记下沉
		b[p * 2 + 1] += b[p];
		b[p] = 0;
		
	}
	if (l <= m) add(l, r, c, s, m, p * 2);
	if (r > m) add(l, r, c, m + 1, t, p * 2 + 1);
	d[p] = (d[p * 2] + d[p * 2 + 1]) % 998244353;
}

void mul(int l, int r, int c, int s, int t, int p) {
	//区间乘法,l,r是最终目标区间,c是乘的值,可以为负。s,t是当前区间,p是当前区间节点编号
	if (l <= s && t <= r) { //目标全包含当前区间,计算后返回
		d[p] *= c;
		b2[p] *= c; //标记好
		d[p] %= 998244353;
		return;
	}
	//具体要往子节点访问,更新子节点并消除标记
	int m = s + ((t - s) / 2);
	if (b2[p]>1 && s != t) { //非叶子带标记,更新子节点值并下沉标记
		d[p * 2] *= b2[p];//这里是乘积,直接乘就行了
		d[p * 2 + 1] *= b2[p];
		d[p * 2] %= 998244353;
		d[p * 2 + 1] %= 998244353;
		b2[p * 2] *= b2[p]; //标记下沉
		b2[p * 2 + 1] *= b2[p];
		b2[p] = 1;

	}
	if (l <= m) mul(l, r, c, s, m, p * 2);
	if (r > m) mul(l, r, c, m + 1, t, p * 2 + 1);
	d[p] = (d[p * 2] + d[p * 2 + 1]) % 998244353;
}

int main() {
	for (int i = 1; i <= 10; i++) { a[i] = 1; b[i] = 0; b2[i] = 1; }
	build(1, 10, 1);
	mul(2, 6, 2, 1, 10, 1);
	printf("%d", getsum(1, 10, 1, 10, 1));
 }

最后更新于