线段树
2025/6/1大约 11 分钟
线段树
模板
普通线段树
支持单点修改、增加、查询,区间修改、增加、求和、最值查询功能
有注释
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10; // 原始数组大小
struct segTree{ // 基础线段树模板
// 操作:单点修改、增加、查询,区间修改、增加、求和、最值查询
// 所有操作时间复杂度均为 O(log n)
/*
线段树节点结构:
l, r: 节点表示的区间左右端点
sum: 区间和
max_val: 区间最大值
min_val: 区间最小值
add_tag: 加法懒标记,用于延迟更新,0表示没有待添加的值
set_tag: 赋值懒标记,用于延迟更新,无穷大表示未设置
need_set: 标记是否需要执行赋值操作
函数参数说明:
p: 当前节点索引
L, R: 当前节点表示的区间
pos: 要单点操作的位置
l, r: 要操作的区间
value: 要设置的值
delta: 要增加的值
proto[]: 原始数组
n: 原始数组长度
*/
struct node{
ll l, r, sum, max_val, min_val;
ll add_tag = 0;
ll set_tag = LLONG_MAX;
bool need_set = false;
} t[4*N];
// 标记下推函数
void pushdown(ll p, ll L, ll R){
ll mid = (L + R) / 2;
ll lc = p*2, rc = lc+1;
// 先处理设置标记(优先级更高)
if(t[p].need_set && L<R){
// 更新左子节点
t[lc].sum = t[p].set_tag * (mid - L + 1); // 区间和 = 值 * 区间长度
t[lc].max_val = t[p].set_tag; // 最大值等于设置值
t[lc].min_val = t[p].set_tag; // 最小值等于设置值
t[lc].set_tag = t[p].set_tag; // 传递设置标记
t[lc].need_set = true; // 标记子节点需要更新
t[lc].add_tag = 0; // 清除加法标记,因为赋值操作会覆盖之前的加法
// 更新右子节点
t[rc].sum = t[p].set_tag * (R - mid);
t[rc].max_val = t[p].set_tag;
t[rc].min_val = t[p].set_tag;
t[rc].set_tag = t[p].set_tag;
t[rc].need_set = true;
t[rc].add_tag = 0;
// 清除当前节点的设置标记
t[p].need_set = false;
t[p].set_tag = LLONG_MAX;
}
// 处理加法标记
if(t[p].add_tag && L < R){ // 当存在加法标记且不是叶子节点时
// 更新左子节点
t[lc].sum += t[p].add_tag * (mid - L + 1); // 区间和增加
t[lc].max_val += t[p].add_tag; // 最大值增加
t[lc].min_val += t[p].add_tag; // 最小值增加
t[lc].add_tag += t[p].add_tag; // 传递加法标记
// 更新右子节点
t[rc].sum += t[p].add_tag * (R - mid);
t[rc].max_val += t[p].add_tag;
t[rc].min_val += t[p].add_tag;
t[rc].add_tag += t[p].add_tag;
// 清除当前节点的加法标记
t[p].add_tag = 0;
}
}
// 向上更新父节点信息
void update_father(ll p){
t[p].sum = t[p*2].sum + t[p*2+1].sum; // 父节点和等于子节点和之和
t[p].max_val = max(t[p*2].max_val, t[p*2+1].max_val); // 父节点最大值为子节点最大值的较大值
t[p].min_val = min(t[p*2].min_val, t[p*2+1].min_val); // 父节点最小值为子节点最小值的较小值
}
// 线段树建树
void build(ll proto[], ll p, ll L, ll R){
t[p].l = L, t[p].r = R; // 记录当前节点表示的区间
if(L == R){ // 叶子节点,直接存储原始数组的值
t[p].sum = proto[L];
t[p].max_val = proto[L];
t[p].min_val = proto[L];
return;
}
ll mid = (L + R) >> 1; // 计算区间中点
// 递归构建左右子树
build(proto, p*2, L, mid);
build(proto, p*2+1, mid+1, R);
// 更新父节点信息
update_father(p);
}
// 单点修改
void point_update(ll p, ll pos, ll value, ll L, ll R){
if(L == R){ // 找到目标叶子节点
t[p].sum = value;
t[p].max_val = value;
t[p].min_val = value;
return;
}
pushdown(p, L, R); // 下推标记
ll mid = (L + R) >> 1;
// 根据位置决定递归左子树还是右子树
if(pos <= mid) point_update(p*2, pos, value, L, mid);
else point_update(p*2+1, pos, value, mid+1, R);
update_father(p); // 更新父节点信息
}
// 单点增加
void point_add(ll p, ll pos, ll delta, ll L, ll R){
if(L == R){ // 找到目标叶子节点
t[p].sum += delta;
t[p].max_val += delta;
t[p].min_val += delta;
return;
}
pushdown(p, L, R); // 下推标记
ll mid = (L + R) >> 1;
// 根据位置决定递归左子树还是右子树
if(pos <= mid) point_add(p*2, pos, delta, L, mid);
else point_add(p*2+1, pos, delta, mid+1, R);
update_father(p); // 更新父节点信息
}
// 单点查询
ll point_query(ll p, ll pos, ll L, ll R){
if(L == R) return t[p].sum; // 找到目标叶子节点
pushdown(p, L, R); // 下推标记
ll mid = (L + R) >> 1;
// 根据位置决定递归左子树还是右子树
if(pos <= mid) return point_query(p*2, pos, L, mid);
else return point_query(p*2+1, pos, mid+1, R);
}
// 区间修改
void range_set(ll p, ll l, ll r, ll value, ll L, ll R){
if(l <= L && R <= r){ // 当前节点区间完全包含在目标区间内
t[p].sum = value * (R - L + 1); // 更新区间和
t[p].max_val = value; // 更新最大值
t[p].min_val = value; // 更新最小值
t[p].set_tag = value; // 设置赋值标记
t[p].need_set = true; // 标记需要执行赋值操作
t[p].add_tag = 0; // 清除加法标记
return;
}
pushdown(p, L, R); // 下推标记
ll mid = (L + R) >> 1;
// 递归处理左右子树
if(l <= mid) range_set(p*2, l, r, value, L, mid);
if(r > mid) range_set(p*2+1, l, r, value, mid+1, R);
update_father(p); // 更新父节点信息
}
// 区间增加
void range_add(ll p, ll l, ll r, ll delta, ll L, ll R){
if(l <= L && R <= r){ // 当前节点区间完全包含在目标区间内
t[p].sum += delta * (R - L + 1); // 更新区间和
t[p].max_val += delta; // 更新最大值
t[p].min_val += delta; // 更新最小值
t[p].add_tag += delta; // 更新加法标记
return;
}
pushdown(p, L, R); // 下推标记
ll mid = (L + R) >> 1;
// 递归处理左右子树
if(l <= mid) range_add(p*2, l, r, delta, L, mid);
if(r > mid) range_add(p*2+1, l, r, delta, mid+1, R);
update_father(p); // 更新父节点信息
}
// 区间求和
ll range_sum(ll p, ll l, ll r, ll L, ll R){
if(l <= L && R <= r) return t[p].sum; // 当前节点区间完全包含在目标区间内
pushdown(p, L, R); // 下推标记
ll mid = (L + R) >> 1, sum = 0;
// 累加左右子树中与目标区间重叠部分的和
if(l <= mid) sum += range_sum(p*2, l, r, L, mid);
if(r > mid) sum += range_sum(p*2+1, l, r, mid+1, R);
return sum;
}
// 区间最大值查询
ll range_max(ll p, ll l, ll r, ll L, ll R){
if(l <= L && R <= r) return t[p].max_val; // 当前节点区间完全包含在目标区间内
pushdown(p, L, R); // 下推标记
ll mid = (L + R) >> 1, res = LLONG_MIN; // 初始化为最小值
// 取左右子树中与目标区间重叠部分的最大值
if(l <= mid) res = max(res, range_max(p*2, l, r, L, mid));
if(r > mid) res = max(res, range_max(p*2+1, l, r, mid+1, R));
return res;
}
// 区间最小值查询
ll range_min(ll p, ll l, ll r, ll L, ll R){
if(l <= L && R <= r) return t[p].min_val; // 当前节点区间完全包含在目标区间内
pushdown(p, L, R); // 下推标记
ll mid = (L + R) >> 1, res = LLONG_MAX; // 初始化为最大值
// 取左右子树中与目标区间重叠部分的最小值
if(l <= mid) res = min(res, range_min(p*2, l, r, L, mid));
if(r > mid) res = min(res, range_min(p*2+1, l, r, mid+1, R));
return res;
}
// 函数操作接口
void init(ll proto[], ll n){ build(proto, 1, 1, n); }
void Set(ll pos, ll value){ point_update(1, pos, value, t[1].l, t[1].r); }
void Add(ll pos, ll delta){ point_add(1, pos, delta, t[1].l, t[1].r); }
ll Get(ll pos){ return point_query(1, pos, t[1].l, t[1].r); }
void qset(ll l, ll r, ll value){ range_set(1, l, r, value, t[1].l, t[1].r); }
void qadd(ll l, ll r, ll delta){ range_add(1, l, r, delta, t[1].l, t[1].r); }
ll qsum(ll l, ll r){ return range_sum(1, l, r, t[1].l, t[1].r); }
ll qmax(ll l, ll r){ return range_max(1, l, r, t[1].l, t[1].r); }
ll qmin(ll l, ll r){ return range_min(1, l, r, t[1].l, t[1].r); }
};
segTree T;
ll a[N],n,m;
int main(){
cin>>n>>m;
for(ll i=1;i<=n;i++)cin>>a[i];
T.init(a,n);
while(m--){
ll op,x,y,k;
cin>>op>>x>>y;
if(op==1){
cin>>k;
T.qadd(x,y,k);
}
else cout<<T.qsum(x,y)<<'\n';
}
return 0;
}
无注释
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10;
struct segTree{
struct node{
ll l,r,sum,max_val,min_val,add_tag=0,set_tag=LLONG_MAX;
bool need_set=false;
}t[4*N];
void pushdown(ll p,ll L,ll R){
ll mid=(L+R)/2,lc=p*2,rc=lc+1;
if(t[p].need_set&&L<R){
t[lc].sum=t[p].set_tag*(mid-L+1);
t[lc].min_val=t[lc].set_tag=t[lc].max_val=t[p].set_tag;
t[lc].need_set=true,t[lc].add_tag=0;
t[rc].sum=t[p].set_tag*(R-mid);
t[rc].min_val=t[rc].set_tag=t[rc].max_val=t[p].set_tag;
t[rc].need_set=true,t[rc].add_tag=0;
t[p].need_set=false,t[p].set_tag=LLONG_MAX;
}
if(t[p].add_tag&&L<R){
t[lc].sum+=t[p].add_tag*(mid-L+1);
t[rc].sum+=t[p].add_tag*(R-mid);
t[lc].max_val+=t[p].add_tag,t[rc].max_val+=t[p].add_tag;
t[lc].min_val+=t[p].add_tag,t[rc].min_val+=t[p].add_tag;
t[lc].add_tag+=t[p].add_tag,t[rc].add_tag+=t[p].add_tag;
t[p].add_tag = 0;
}
}
void update_father(ll p){
t[p].sum=t[p*2].sum+t[p*2+1].sum;
t[p].max_val=max(t[p*2].max_val,t[p*2+1].max_val);
t[p].min_val=min(t[p*2].min_val,t[p*2+1].min_val);
}
void build(ll proto[],ll p,ll L,ll R){
t[p].l=L,t[p].r=R;
if(L==R){
t[p].sum=proto[L];
t[p].max_val=proto[L];
t[p].min_val=proto[L];
return;
}
ll mid=(L+R)>>1;
build(proto,p*2,L,mid);
build(proto,p*2+1,mid+1,R);
update_father(p);
}
void point_update(ll p,ll pos,ll value,ll L,ll R){
if(L==R){
t[p].sum=t[p].max_val=t[p].min_val=value;
return;
}
pushdown(p,L,R);
ll mid=(L+R)>>1;
if(pos<=mid)point_update(p*2,pos,value,L,mid);
else point_update(p*2+1,pos,value,mid+1,R);
update_father(p);
}
void point_add(ll p,ll pos,ll delta,ll L,ll R){
if(L==R){
t[p].sum+=delta;
t[p].max_val+=delta;
t[p].min_val+=delta;
return;
}
pushdown(p,L,R);
ll mid=(L+R)>>1;
if(pos<=mid)point_add(p*2,pos,delta,L,mid);
else point_add(p*2+1,pos,delta,mid+1,R);
update_father(p);
}
ll point_query(ll p,ll pos,ll L,ll R){
if(L==R)return t[p].sum;
pushdown(p,L,R);
ll mid=(L+R)>>1;
if(pos<=mid)return point_query(p*2,pos,L,mid);
else return point_query(p*2+1,pos,mid+1,R);
}
void range_set(ll p,ll l,ll r,ll value,ll L,ll R){
if(l<=L&&R<=r){
t[p].sum=value*(R-L+1);
t[p].max_val=t[p].min_val=t[p].set_tag=value;
t[p].need_set=true,t[p].add_tag = 0;
return;
}
pushdown(p,L,R);
ll mid=(L+R)>>1;
if(l<=mid)range_set(p*2,l,r,value,L,mid);
if(r>mid)range_set(p*2+1,l,r,value,mid+1,R);
update_father(p);
}
void range_add(ll p,ll l,ll r,ll delta,ll L,ll R){
if(l<=L&&R<=r){
t[p].sum+=delta*(R-L+1);
t[p].max_val+=delta;
t[p].min_val+=delta;
t[p].add_tag+=delta;
return;
}
pushdown(p, L, R);
ll mid=(L+R)>>1;
if(l<=mid)range_add(p*2,l,r,delta,L,mid);
if(r>mid)range_add(p*2+1,l,r,delta,mid+1,R);
update_father(p);
}
ll range_sum(ll p,ll l,ll r,ll L,ll R){
if(l<=L&&R<=r)return t[p].sum;
pushdown(p,L,R);
ll mid=(L+R)>>1,sum=0;
if(l<=mid)sum+=range_sum(p*2,l,r,L,mid);
if(r>mid)sum+=range_sum(p*2+1,l,r,mid+1,R);
return sum;
}
ll range_max(ll p,ll l,ll r,ll L,ll R){
if(l<=L&&R<=r)return t[p].max_val;
pushdown(p,L,R);
ll mid=(L+R)>>1,res=LLONG_MIN;
if(l<=mid)res=max(res,range_max(p*2,l,r,L,mid));
if(r>mid)res=max(res,range_max(p*2+1,l,r,mid+1,R));
return res;
}
ll range_min(ll p,ll l,ll r,ll L,ll R){
if(l<=L&&R<=r)return t[p].min_val;
pushdown(p,L,R);
ll mid=(L+R)>>1,res=LLONG_MAX;
if(l<=mid)res=min(res,range_min(p*2,l,r,L,mid));
if(r>mid)res=min(res,range_min(p*2+1,l,r,mid+1,R));
return res;
}
void init(ll proto[],ll n){build(proto,1,1,n);}
void Set(ll pos,ll value){point_update(1,pos,value,t[1].l,t[1].r);}
void Add(ll pos,ll delta){point_add(1,pos,delta,t[1].l,t[1].r);}
ll Get(ll pos){ return point_query(1,pos,t[1].l,t[1].r);}
void qset(ll l,ll r,ll value){ range_set(1,l,r,value,t[1].l,t[1].r);}
void qadd(ll l,ll r,ll delta){ range_add(1,l,r,delta,t[1].l,t[1].r);}
ll qsum(ll l,ll r){ return range_sum(1,l,r,t[1].l,t[1].r);}
ll qmax(ll l,ll r){ return range_max(1,l,r,t[1].l,t[1].r);}
ll qmin(ll l,ll r){ return range_min(1,l,r,t[1].l,t[1].r);}
};
segTree T;
ll a[N],n,m;
int main(){
cin>>n>>m;
for(ll i=1;i<=n;i++)cin>>a[i];
T.init(a,n);
while(m--){
ll op,x,y,k;
cin>>op>>x>>y;
if(op==1){
cin>>k;
T.qadd(x,y,k);
}
else cout<<T.qsum(x,y)<<'\n';
}
return 0;
}
ZKW 线段树
码量:约为原来的
时间:约为原来的
空间:约为原来的
typedef long long ll;
const int N=1e5+10;
ll tr[N*3],sum[N*3],n,m,P=1;
void add(int l,int r,ll k){
int siz=1;
for(l=P+l-1,r=P+r+1;l^1^r;){
if(~l&1)tr[l^1]+=siz*k,sum[l^1]+=k;
if(r&1)tr[r^1]+=siz*k,sum[r^1]+=k;
l>>=1,r>>=1,siz<<=1;
tr[l]=tr[l<<1]+tr[l<<1|1]+sum[l]*siz;
tr[r]=tr[r<<1]+tr[r<<1|1]+sum[r]*siz;
}
for(l>>=1,siz<<=1;l;l>>=1,siz<<=1)tr[l]=tr[l<<1]+tr[l<<1|1]+sum[l]*siz;
}
ll query(int l,int r){
ll res=0,sizl=0,sizr=0,siz=1;
for(l=l+P-1,r=r+P+1;l^1^r;){
if(~l&1)res+=tr[l^1],sizl+=siz;
if(r&1)res+=tr[r^1],sizr+=siz;
l>>=1,r>>=1,siz<<=1;
res+=sum[l]*sizl+sum[r]*sizr;
}
for(l>>=1,sizl+=sizr;l;l>>=1)res+=sum[l]*sizl;
return res;
}