树链剖分(线段树的应用)

    xiaoxiao2022-07-05  151

    【题目描述】

    原题来自:ZJOI 2008

    一树上有 n 个节点,编号分别为 1 到 n,每个节点都有一个权值 w。我们将以下面的形式来要求你对这棵树完成一些操作:

    1.CHANGE u t :把节点 u 权值改为 t;

    2.QMAX u v :询问点 u 到点 v 路径上的节点的最大权值;

    3.QSUM u v :询问点 u 到点 v 路径上的节点的权值和。

    注意:从点 u 到点 v 路径上的节点包括 u 和 v 本身。

    【输入】

    第一行为一个数 n,表示节点个数;

    接下来 n−1 行,每行两个整数 a,b,表示节点 a 与节点 b 之间有一条边相连;

    接下来 n 行,每行一个整数,第 i 行的整数 wi 表示节点 i 的权值;

    接下来一行,为一个整数 q ,表示操作总数;

    接下来 q 行,每行一个操作,以 CHANGE u t 或 QMAX u v 或 QSUM u v的形式给出。

    【输出】

    对于每个 QMAX 或 QSUM 的操作,每行输出一个整数表示要求的结果。

    【输入样例】

    4 1 2 2 3 4 1 4 2 1 3 12 QMAX 3 4 QMAX 3 3 QMAX 3 2 QMAX 2 3 QSUM 3 4 QSUM 2 1 CHANGE 1 5 QMAX 3 4 CHANGE 3 6 QMAX 3 4 QMAX 2 4 QSUM 3 4

    【输出样例】

    4 1 2 2 10 6 5 6 5 16

    【提示】

    数据范围与提示:

    对于 100% 的数据,有 1≤n≤3×104,0≤q≤2×105 。中途操作中保证每个节点的权值 w 在 −30000 至 30000 之间。

    #include<bits/stdc++.h>//xyc大佬讲解的程序 #define lc (k<<1) #define rc (k<<1|1) #define mid ((a[k].l+a[k].r)>>1) using namespace std; inline void in(int &x){ int f=1;x=0;char w=getchar(); while(w<'0'||w>'9'){if(w=='-') f=-f;w=getchar();} while(w>='0'&&w<='9'){x=(x<<3)+(x<<1)+(w^48);w=getchar();} x*=f; } const int N=3e4+10; struct node{ int son,fa,deep,size,top,zhi,id; }p[N]; struct tree{ int l,r,val,maxn; }a[N<<2]; int n,m,x,y,tot,cnt;char op[100]; int fir[N],vis[N],rk[N],nxt[N<<1],ver[N<<1]; inline void add(int x,int y){ver[++tot]=y,nxt[tot]=fir[x],fir[x]=tot;} void dfs1(int x){ vis[x]=1,p[x].size=1,p[x].deep=p[p[x].fa].deep+1;int maxn=0; for(int i=fir[x];i;i=nxt[i]){ int y=ver[i];if(vis[y]) continue; p[y].fa=x,dfs1(y);p[x].size+=p[y].size; if(p[y].size>maxn) maxn=p[y].size,p[x].son=y; } } void dfs2(int x){ p[x].id=++cnt,rk[cnt]=x,p[x].top=x==p[p[x].fa].son?p[p[x].fa].top:x; if(!p[x].son) return ;dfs2(p[x].son); for(int i=fir[x];i;i=nxt[i]) {int y=ver[i];if(!p[y].id) dfs2(y);} } void build(int k,int l,int r){ a[k].l=l,a[k].r=r; if(l==r) {a[k].maxn=a[k].val=p[rk[l]].zhi;return;} build(lc,l,mid),build(rc,mid+1,r); a[k].maxn=max(a[lc].maxn,a[rc].maxn); a[k].val=a[lc].val+a[rc].val; } void add(int k,int x,int z){ if(a[k].l==a[k].r){a[k].maxn=a[k].val=z;return;} if(x<=mid) add(lc,x,z);else add(rc,x,z); a[k].maxn=max(a[lc].maxn,a[rc].maxn); a[k].val=a[lc].val+a[rc].val; } int asksum(int k,int l,int r){ if(a[k].l>=l&&a[k].r<=r)return a[k].val; int ans=0; if(l<=mid) ans+=asksum(lc,l,min(r,mid)); if(r>mid) ans+=asksum(rc,max(mid+1,l),r); return ans; } int askmax(int k,int l,int r){ if(a[k].l>=l&&a[k].r<=r)return a[k].maxn; int ans=-N; if(l<=mid) ans=max(ans,askmax(lc,l,min(mid,r))); if(r>mid) ans=max(ans,askmax(rc,max(mid+1,l),r)); return ans; } int main(){ in(n);for(int i=1;i<n;i++) in(x),in(y),add(x,y),add(y,x); for(int i=1;i<=n;i++) in(p[i].zhi); dfs1(1),dfs2(1),build(1,1,n),in(m); for(int i=1;i<=m;i++){ scanf("%s",op),in(x),in(y); if(op[3]=='X'){ int ans=-N; while(p[x].top!=p[y].top) if(p[p[x].top].deep>p[p[y].top].deep) ans=max(ans,askmax(1,p[p[x].top].id,p[x].id)),x=p[p[x].top].fa; else ans=max(ans,askmax(1,p[p[y].top].id,p[y].id)),y=p[p[y].top].fa; ans=max(ans,askmax(1,min(p[x].id,p[y].id),max(p[x].id,p[y].id))); printf("%d\n",ans); } else if(op[3]=='M'){ int ans=0; while(p[x].top!=p[y].top) if(p[p[x].top].deep>p[p[y].top].deep) ans+=asksum(1,p[p[x].top].id,p[x].id),x=p[p[x].top].fa; else ans+=asksum(1,p[p[y].top].id,p[y].id),y=p[p[y].top].fa; ans+=asksum(1,min(p[x].id,p[y].id),max(p[x].id,p[y].id)); printf("%d\n",ans); } else add(1,p[x].id,y); } return 0; } #include<cstdio>//另一位大佬的注释程序 #include<cstring> using namespace std; const int N=31000; const int M=124000; int n,q,k=1,first[N],summ,maxmax; struct Edge{ int v,next;}edge[M]; int num[N]; int father[N],dep[N],size[N],son[N],top[N],seg[N],rev[N]; int maxn[M],sum[M]; int read(){ int s=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){ if(ch=='-') f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){ s=(s<<3)+(s<<1)+ch-48;ch=getchar();} return f*s; } void addedge(int ui,int vi){ edge[++k].v=vi;edge[k].next=first[ui];first[ui]=k; edge[++k].v=ui;edge[k].next=first[vi];first[vi]=k; } void dfs1(int u,int fa){ //第一遍dfs size[u]=1; //u的结点数为1,包含它自己 father[u]=fa; //u的父亲为fa dep[u]=dep[fa]+1; //u的深度为其父亲fa的深度+1 for(int i=first[u];i;i=edge[i].next){ //穷举与u开头的所有边 int v1=edge[i].v; if(v1==fa) continue; dfs1(v1,u); size[u]+=size[v1]; //u包含的结点数累加儿子v1的结点 if(size[v1]>size[son[u]]) //如果v1的结点数>u的重儿子son[u]的结点数 son[u]=v1; //更新重儿子 } } void dfs2(int u,int fa){ //第二遍dfs if(son[u]){ //先走重儿子,保证重路径在线段树上的位置是连续的 seg[son[u]]=++seg[0]; //重儿子son[u]在线段树上的编号为++seg[0] top[son[u]]=top[u]; //重儿子son[u]所在的重路径的顶端结点为其父亲u所在的顶端结点 rev[seg[0]]=son[u]; //线段树上编号为seg[0]的结点号为son[u] dfs2(son[u],u); } for(int i=first[u];i;i=edge[i].next){ int v1=edge[i].v; if(top[v1]) continue; //如果v1已访问过, 即为重儿子或父亲,则不需要再访问 seg[v1]=++seg[0]; rev[seg[0]]=v1; top[v1]=v1; //若(u,v1)为轻边,则v1就是其所在重路径的顶部结点 dfs2(v1,u); } } int max(int x,int y){ if(x>y) return x;return y;} void build(int k,int l,int r){ //建立线段树 int mid=(l+r)>>1; if(l==r){ maxn[k]=sum[k]=num[rev[l]]; return; } build(k<<1,l,mid); build((k<<1)+1,mid+1,r); sum[k]=sum[k<<1]+sum[(k<<1)+1]; maxn[k]=max(maxn[k<<1],maxn[(k<<1)+1]); } void change(int k,int l,int r,int val,int pos){ //修改pos结点,改值为val if(pos<l||r<pos) return; //如果不在范围内,就退出 if(l==r&&l==pos){ //如果找到点,则更改 sum[k]=val; maxn[k]=val; return; } int mid=(l+r)>>1; if(mid>=pos) change(k<<1,l,mid,val,pos); if(mid<pos) change((k<<1)+1,mid+1,r,val,pos); sum[k]=sum[k<<1]+sum[(k<<1)+1]; maxn[k]=max(maxn[k<<1],maxn[(k<<1)+1]); } void swap(int &x,int &y){ int temp=x;x=y;y=temp;} void query(int k,int l,int r,int x,int y){ if(y<l||x>r) return; if(x<=l&&r<=y){ summ+=sum[k]; maxmax=max(maxmax,maxn[k]); return; } int mid=(l+r)>>1; if(x<=mid) query(k<<1,l,mid,x,y); if(mid<y) query((k<<1)+1,mid+1,r,x,y); } void ask(int x,int y){ //路径询问 int fx=top[x],fy=top[y]; //找到xy分别所在重路径的顶端结点 while(fx!=fy){ //退出时xy在同一条重路径上 if(dep[fx]<dep[fy]){ //保证x的深度更大 swap(x,y);swap(fx,fy); } query(1,1,seg[0],seg[fx],seg[x]); x=father[fx];fx=top[x]; } if(dep[x]>dep[y]) swap(x,y); //保证x在线段树上的编号<y query(1,1,seg[0],seg[x],seg[y]); } int main(){ n=read(); for(int i=1;i<n;++i) addedge(read(),read()); for(int i=1;i<=n;++i) num[i]=read(); dfs1(1,0); seg[0]=1; //记录截止到目前线段树的编号 seg[1]=1; //根结点0和1在线段树中的位置为1 top[1]=1; //1所在的重路径顶部结点为1 rev[1]=1; //线段树中第1个位置对应的结点还是1 dfs2(1,0); build(1,1,seg[0]); //建立线段树 q=read(); while(q--){ char st[10]; scanf("%s",st); int ui=read(),vi=read(); if(st[0]=='C') change(1,1,seg[0],vi,seg[ui]); else{ summ=0; maxmax=-100000000; ask(ui,vi); if(st[1]=='M') printf("%d\n",maxmax); else printf("%d\n",summ); } } return 0; }

    一次小尝试:

    #include<cstdio> #include<cstring> using namespace std; const int N=30005; struct node{ int v,next; }e[N<<1]; int n,y,x,q,w[N],first[N],k=0,cnt=1; int sum[N<<2],maxn[N<<2]; char st[10]; int sum1,maxn1 int dep[N],fa[N],size[N],son[N],top[N],seg[N]/*表示图中x点在线段树中的编号,即第二次深搜时的时间戳*/,rev[N]; void add(int x,int y){ e[++k].v=y;e[k].next=first[x];first[x]=k; } void dfs1(int u,int fat){ dep[u]=dep[fa]+1; size[u]=1; for(i=first[u];i;i=e[i].next){ int vi=e[i].v; if(vi==fat) continue; fa[vi]=u; dfs1(vi,u); size[u]+=size[vi]; if(size[vi]>size[son[u]]) son[u]=vi; } } void dfs2(int u,int fat){ if(son[u]){ seg[son[u]]=++cnt; rev[cnt]=son[u]; top[son[u]]=top[u]; dfs2(son[u],u); } for(int i=first[u];i;i=e[i].next){//对轻儿子的操作//即该儿子top为他本身 int vi=e[i].v; if(vi==fat)continue; if(!top[vi]){//表示vi没有访问过 seg[vi]=++cnt; rev[cnt]=vi; top[vi]=vi; dfs2(vi,u); } } } int max(int x,int y){return x > y ? x : y;} void build(int k,int l,int r){//建立线段树,计算对应的和、最大值 if(l==r){ =w[rev[l]]; return; } int mid=(l+r)>>1; build(k<<1,l,mid); built((k<<1)+1,mid+1,r); sum[k]=sum[k<<1]+sum[(k<<1)+1]; maxn[k]=max(maxn[k<<1],maxn[(k<<1)+1]); } void change(int k,int l,int r,int u,int t){ if(u<l||r<u) return; if(l==r&&l==u){ maxn[k]=sum[k]=t; return; } int mid=(l+r)<<1; if(u<=mid) change(k<<1,l,mid,u,t); else change((k<<1)+1,mid+1,r,u,t); sum[k]=sum[k<<1]+sum[(k<<1)+1]; maxn[k]=max(maxn[k<<1],maxn[(k<<1)+1]); } void query(int k,int l,int r,int x,int y){ if(y<l||x>r)return; if(l<=x&&y<=r){ sum1+=sum[k]; maxn1=max(maxn1,maxn[k]); return; } int mid=(l+r)<<1; if(x<=mid)query(k<<1,l,mid,x,y); if(mid+1<=y)query((k<<1)+1,mid+1,r,x,y) } int swap(int &x,int &y){int temp=x;x=y;y=temp;} void ask(int x,int y){ int fx=top[x],fy=top[y]; while(fx!=fy){ if(dep[fx]<dep[fy]){ swap(x,y),swap(fa,fy); } query(1,1,n,seg[fx],seg[x]); x=fa[fx]; fx=top[x]; } if(dep[]) } int main(){ scanf("%d",&n); for(int i=1;i<=n;i++){ scanf("%d%d",&x,&y); ad(x,y),add(y,x); } for(int i=1;i<=n;i++)scanf("%d",&w[i]); dfs1(1,0); seg[1]=cnt;rev[1]=1;top[1]=1; dfs2(1,0); build(1,1,n); scanf("%d",&q); while(q--){ scanf("%s%d%d",st,&x,&y); if(st[0]=='C') change(1,1,n,seg[x],y); else{ sum1=0;maxn1=-3000000; ask(x,y); if(st[1]=='M') printf("%d",maxn1); else printf("%d",sum1); } } return 0; }

    又一个标程:

    #include<cstdio> #include<cstring> #define lc (k<<1) #define rc (k<<1|1) using namespace std; const int N=3e4+5; struct node{ int vi,next;}edge[N<<1]; int n,k,x,y,cnt=0,q,max1,sum1; int son[N],fa[N],dep[N],size[N],top[N],id[N],rev[N]; int first[N],w[N],maxn[N<<2],sum[N<<2]; int read(){ //快速读入 int s=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){ if(ch=='-') f=-1;ch=getchar(); } while(ch>='0'&&ch<='9'){ s=(s<<3)+(s<<1)+ch-48;ch=getchar(); } return s*f; } void ADD(int x,int y){ edge[++k].vi=y;edge[k].next=first[x];first[x]=k;} //邻接表存储树 void dfs1(int u,int father){ //第一次深搜 fa[u]=father; //u的父亲是father dep[u]=dep[father]+1; //u的深度为父亲的深度加1 size[u]=1; //以u为根的子树结点数初始值为1,因为自己也算一个结点 for(int i=first[u];i;i=edge[i].next){ //访问u的所有儿子 int v=edge[i].vi; if(v==father) continue; //如果儿子为父亲,则不访问,因为是无向图 dfs1(v,u); //深搜儿子 size[u]+=size[v]; //u的结点数加上其儿子v的结点数 if(size[v]>size[son[u]]) son[u]=v; //如果v的结点数>u的重儿子的结点数,则更新重儿子 } } void dfs2(int u){ //第二次深搜 id[u]=++cnt; //u的时间戳是cnt,即在线段树上的编号 rev[cnt]=u; //在线段树上编号位cnt的,对应树上的结点号是u if(son[u]){ //如果u有重儿子 top[son[u]]=top[u]; //则u的重儿子son[u]所在重路径的深度最小的顶点=u的重路径顶点 dfs2(son[u]); //深搜重儿子 } for(int i=first[u];i;i=edge[i].next){ //访问u的所有儿子 int v=edge[i].vi; if(id[v]) continue; //如果儿子v访问过,则不需要再次访问,比如重儿子 top[v]=v; //不是重儿子的v的top值为自己 dfs2(v); //深搜儿子 } } int max(int x,int y){ if(x>y) return x;return y;} void build(int k,int l,int r){ //建立线段树 if(l==r){ //如果是叶子结点 maxn[k]=sum[k]=w[rev[l]]; return; } int mid=(l+r)>>1; build(lc,l,mid);build(rc,mid+1,r); maxn[k]=max(maxn[lc],maxn[rc]); sum[k]=sum[lc]+sum[rc]; } void change(int k,int l,int r,int u,int val){ //将结点u的值改为val if(u<l||r<u) return; //如果u不在区间[l,r]中,则退出 if(l==r&&l==u){ //如果找到u,则更新 maxn[k]=sum[k]=val; return; } int mid=(l+r)>>1; if(u<=mid) change(lc,l,mid,u,val); //如果u在左子树中,则更改左子树 else change(rc,mid+1,r,u,val); //否则更改右子树 maxn[k]=max(maxn[lc],maxn[rc]); //更新区间[l,r]的最大值 sum[k]=sum[lc]+sum[rc]; //更新区间[l,r]的和 } void swap(int &x,int &y){int temp=x;x=y;y=temp;} void query(int k,int l,int r,int x,int y){ //询问线段上[x,y]区间上的值 if(y<l||r<x) return; //如果区间[l,r]和[x,y]无交集,则退出 if(x<=l&&r<=y){ //如果区间[x,y]包含[l,r],则更新max1,sum1 max1=max(max1,maxn[k]); sum1+=sum[k]; return; } int mid=(l+r)>>1; if(mid>=x) query(lc,l,mid,x,y); //如果区间[x,y]与[l,mid]有交集,则查找左子树 if(mid+1<=y) query(rc,mid+1,r,x,y); //如果区间[x,y]与[mid+1,r]有交集,则查找右子树 } void ask(int u,int v){ //询问树上结点u,v之间的值 while(top[u]!=top[v]){ //如果u,v不在一个重路径上 if(dep[top[u]]<dep[top[v]]){ swap(u,v); } //保证top[u]的深度大于top[v],否则交换两者 query(1,1,n,id[top[u]],id[u]); //因为后面u要跳到top[u]的父亲处,故要将u~top[u]之间的路径上的值更新,其在线段树上的编号是连续的,即id[u ~id[top[u]] u=fa[top[u]]; //u跳到其top[u]的父亲处 } if(dep[u]<dep[v]) swap(u,v); //while结束时,u和v必然在同一个重路径上,if是为了保证u的深度大于v query(1,1,n,id[v],id[u]); //结点v~u在线段树上的编号是连续的,故查找id[v]~id[u] } int main(){ //freopen("count.in","r",stdin);freopen("count.out","w",stdout); n=read(); for(int i=1;i<n;++i){ x=read();y=read();ADD(x,y);ADD(y,x);} for(int i=1;i<=n;++i) w[i]=read(); dfs1(1,0); dfs2(1); //for(int i=1;i<=n;++i) printf("%d:son%d fa%d dep%d size%d top%d id%d rev%d\n",i,son[i],fa[i],dep[i],size[i],top[i],id[i],rev[i] ); build(1,1,cnt); q=read(); while(q--){ char st[10]; scanf("%s",st);x=read();y=read(); if(st[0]=='C') change(1,1,n,id[x],y); else{ max1=-3000000;sum1=0; ask(x,y); if(st[1]=='M') printf("%d\n",max1); else printf("%d\n",sum1); } } //fclose(stdin);fclose(stdout); return 0; }
    最新回复(0)