Logo __vector__ 的博客

博客

日照夏令营 day2 regress 题解

...
__vector__
2025-12-01 12:55:47

本文章由 WyOJ Shojo 从洛谷专栏拉取,原发布时间为 2022-07-26 16:27:31

做法

对于一个点的 $k$ 级祖先,可以用倍增或长链剖分求出,十分容易。

对于一个点的所有 $k$ 级子孙,因为一个点 $i$ 的 $k$ 级子孙的深度是 $dep_i - k$,所以要查询的是,以 $i$ 为根的子树中,有多少个深度为 $dep_i + k$ 的。
所以可以对于每个点 $i$,建一棵权值线段树,维护以 $i$ 为根的子树中深度为 $j$(代表任意深度) 的节点有多少个。
这样通过线段树合并就能处理出所有信息了。

Code

#include <bits\/stdc++.h>
using namespace std;
namespace Main
{
	typedef long long ll;
	const int maxn=3e5+5;
	int n,m;
	int log[maxn];
	int head[maxn];
	struct EDGE
	{
		int to,nxt;
	}edge[maxn<<1];
	int cnt=0;
	inline void add(int u,int to)
	{
		edge[++cnt].to=to;
		edge[cnt].nxt=head[u];
		head[u]=cnt;
	}
	int h[maxn],hs[maxn],fa[20][maxn];
	int dep[maxn];
	int ans[maxn];
	\/\/这里的h[i]是以i为根的子树的深度
	int count[maxn];
	\/\/count[i]是深度为i的点的个数
	void dfs1(int u,int _fa)
	{
		h[u]=dep[u]=dep[_fa]+1;
		fa[0][u]=_fa;
		count[dep[u]]++;
		for(int i=1;i<=19;i++)
		{
			fa[i][u]=fa[i-1][fa[i-1][u]];
		}
		for(int i=head[u];i;i=edge[i].nxt)
		{
			int to=edge[i].to;
			if(to==_fa)continue;
			dfs1(to,u);
			h[u]=max(h[u],h[to]);
			if(h[to]>h[hs[u]])hs[u]=to;
		}
		h[u]++;
	}
	int top[maxn];
	vector<int> up[maxn],down[maxn];
	void dfs2(int u,int _top)
	{
		top[u]=_top;
		if(u==top[u])
		{
			for(int i=0,now=u;i<=h[u]-dep[u];i++)
			{

				up[u].emplace_back(now);
				now=fa[0][now];
			}
			for(int i=0,now=u;i<=h[u]-dep[u];i++)
			{
				down[u].emplace_back(now);
				now=hs[now];
			}
		}
		if(hs[u])
		{
			dfs2(hs[u],_top);
		}
		for(int i=head[u];i;i=edge[i].nxt)
		{
			int to=edge[i].to;
			if(to==fa[0][u]||to==hs[u])continue;
			dfs2(to,to);
		}
	}
	int ask(int x,int k)
	{
		if(k==0)
		{
			return x;
		}
		int mbsd=dep[x]-k;\/\/目标深度
		if(mbsd<=0)return -1;
		int dqd=top[fa[log[k]][x]];
		if(dep[dqd]<mbsd)
		{
			dqd=down[dqd][mbsd-dep[dqd]];
		}
		if(dep[dqd]>mbsd)
		{
			dqd=up[dqd][dep[dqd]-mbsd];
		}
		return dqd;
	}
	struct Tree
	{
	    int ls,rs,val;
	}tree[maxn*100];
	int nodecnt;
	inline void push_up(int node)
	{
	    tree[node].val=tree[tree[node].ls].val+tree[tree[node].rs].val;
	}
	int rt[maxn];
    int modify(int node,int l,int r,int pos,int val)
    {
        if(!node)node=++nodecnt;
        if(l==r)
        {
            tree[node].val+=val;
            return node;
        }
        int mid=l+r>>1;
        if(mid>=pos)
        {
            tree[node].ls=modify(tree[node].ls,l,mid,pos,val);
        }
        else tree[node].rs=modify(tree[node].rs,mid+1,r,pos,val);
        push_up(node);
        return node;
    }
    int merge(int a,int b,int l,int r)
    {
        if(!a||!b)return a|b;
        int root=++nodecnt;
        if(l==r)
        {
            tree[root].val=tree[a].val+tree[b].val;
            return root;
        }
        int mid=l+r>>1;
        tree[root].ls=merge(tree[a].ls,tree[b].ls,l,mid);
        tree[root].rs=merge(tree[a].rs,tree[b].rs,mid+1,r);
        push_up(root);
        return root;
    }
    int query(int node,int l,int r,int pos)
    {
        if(l==r)
        {
            return tree[node].val;
        }
        int mid=l+r>>1;
        int ans=0;
        if(mid>=pos)ans=query(tree[node].ls,l,mid,pos);
        else ans=query(tree[node].rs,mid+1,r,pos);
        return ans;
    }
    struct Question
    {
        int k,id;
        Question(int k2,int id2)
        {
            k=k2;
            id=id2;
        }
    };
    vector<Question> qs[maxn];
    void solve(int u,int _fa)
    {
        for(int i=head[u];i;i=edge[i].nxt)
        {
            int to=edge[i].to;
            if(to==_fa)continue;
            solve(to,u);
            rt[u]=merge(rt[u],rt[to],1,n+1);
        }
        rt[u]=modify(rt[u],1,n+1,dep[u],1);
        for(int i=0;i<qs[u].size();i++)
        {
            ans[qs[u][i].id]=query(rt[u],1,n+1,dep[u]+qs[u][i].k)-1;
            if(ans[qs[u][i].id]<0)ans[qs[u][i].id]=0;
        }
    }
	void main()
	{
		scanf("%d",&n);
		for(int i=2;i<=n;i++)
		{
			log[i]=log[i>>1]+1;
		}
		int __fa;
		for(int i=1;i<=n;i++)
		{
			scanf("%d",&__fa);
			add(i,__fa);
			add(__fa,i);
		}
		dfs1(0,-1);
		dfs2(0,0);
		scanf("%d",&m);
		for(int i=1;i<=m;i++)
		{
			int xi,ki;
			scanf("%d%d",&xi,&ki);
			int zx=ask(xi,ki);
			if(zx!=-1)
            {
                qs[zx].emplace_back(ki,i);
            }

		}
		solve(0,-1);
		for(int i=1;i<=m;i++)
        {
            printf("%d ",ans[i]);
        }
	}
}
int main()
{
	Main::main();
	return 0;
}

评论

暂无评论

发表评论

可以用@mike来提到mike这个用户,mike会被高亮显示。如果你真的想打“@”这个字符,请用“@@”。