题目链接:https://cn.vjudge.net/problem/SPOJ-COT2
题意:n个点的树,每个点有一个权值,求两点路径上所有点权值有多少种
题解:先跑出dfs序,这里每个节点记录两次,第一次遇见记录一次,所有孩子跑完记录一次,这样在进行莫队的时候,对于每个节点第一次遇见就加上,第二次遇见就减去,若他们的lca是其中一个节点的话,直接进行即可,如果不是,那么这样进行的话lca这个节点是没算上的,所以要特判一下。
#include<bits/stdc++.h> using namespace std; const int N=40010; const int M=100100; int CM; struct node{ int l,r,lca,id; bool operator <(const node &xx)const { if(l/CM != xx.l/CM) return l/CM < xx.l/CM; else return r < xx.r; } }q[M]; int ans[M]; struct edge{ int to,nex; }e[N*2]; int head[N],len; int val[N],b[N]; int in[N],out[N],p[N*2],cnt; int dep[N]; int f[N][22]; int vis[N]; int n,m; int sum[N],res; void init() { len=0; cnt=0; for(int i=1;i<=n;i++) { head[i]=-1; sum[i]=0; vis[i]=0; for(int j=0;j<22;j++) { f[i][j]=0; } } } void add_edge(int x,int y) { e[len].to=y; e[len].nex=head[x]; head[x]=len++; } void dfs(int u,int fa) { for(int i=1;i<20;i++) if(f[f[u][i-1]][i-1]) f[u][i]=f[f[u][i-1]][i-1]; else break; in[u]=++cnt; // cout<<u<<" "; p[cnt]=u; int to; for(int i=head[u];i!=-1;i=e[i].nex) { to=e[i].to; if(to==fa) continue; f[to][0]=u; dep[to]=dep[u]+1; dfs(to,u); } out[u]=++cnt; p[cnt]=u; } int get_lca(int x,int y) { if(dep[x]<dep[y]) swap(x,y); int tmp=dep[x]-dep[y]; for(int i=0;i<20;i++) if((1<<i) & tmp) x=f[x][i]; if(x==y) return x; // cout<< x <<" --" <<y<<endl; for(int i=19;i>=0;i--) { if(f[x][i] != f[y][i]) { x=f[x][i]; y=f[y][i]; // cout<<x<<" "<<y<<endl; } } return f[x][0]; } void update(int x) { vis[x]^=1; if(vis[x]) { sum[val[x]]++; if(sum[val[x]]==1) res++; } else { sum[val[x]]--; if(sum[val[x]]==0) res--; } // cout<<x<<" --- "<<sum[val[x]]<<endl; } int main() { int x,y,z; int l,r; int tmp; while(~scanf("%d%d",&n,&m)) { CM=(int)sqrt(n*2+0.5); init(); for(int i=1;i<=n;i++) scanf("%d",&val[i]),b[i]=val[i]; sort(b+1,b+1+n); tmp=unique(b+1,b+1+n)-(b+1); for(int i=1;i<=n;i++) { val[i]=lower_bound(b+1,b+1+tmp,val[i])-b; // cout<<val[i]<<endl; } for(int i=1;i<n;i++) { scanf("%d%d",&x,&y); add_edge(x,y); add_edge(y,x); } dep[1]=1; dfs(1,0); //cout<<endl; for(int i=1;i<=m;i++) { scanf("%d%d",&x,&y); if(in[x]>in[y]) swap(x,y); z=get_lca(x,y); q[i].id=i; q[i].lca=0; // cout<<z<<endl; if(z==x || z==y) q[i].l=in[x], q[i].r=in[y]; else q[i].l=out[x], q[i].r=in[y], q[i].lca=z; } sort(q+1,q+1+m); l=1,r=0; res=0; sum[0]=1; for(int i=1;i<=m;i++) { while(l<q[i].l) { update(p[l++]); } while(l>q[i].l) { update(p[--l]); } while(r>q[i].r) { update(p[r--]); } while(r<q[i].r) { update(p[++r]); } // cout<<res<<" * "<<sum[val[q[i].lca]]<< endl; ans[q[i].id]=res+(sum[val[q[i].lca]]==0); } for(int i=1;i<=m;i++) printf("%d\n",ans[i]); } return 0; }