本文章由 WyOJ Shojo 从洛谷专栏拉取,原发布时间为 2025-05-09 18:01:18
题解好像大部分是差分拆贡献,这是 SD 二轮省集 Harry27182 老师的做法,感觉很牛,在此记录。
注:以下用 $w(x,y)$ 代表原题中的 $f(x,y)$,$mn_u$ 表示 $u$ 子树内 $a$ 的最小值。
显然可以先把 $a$ 离散化,之后考虑对最终答案相同的方案整体计算贡献。设 $f_{u,i}$ 表示对于所有合法的 $w(u,x)$,它们中值为 $i$ 的期望个数。转移时讨论子树内路径最小值是否仍为最小,乘上组合数转移。
具体地,对于 $u$ 的儿子 $v$,若以 $v$ 为起点时的最小值 $i$ 仍最小,需要 $u$ 其他子树中 $mn_x\le i$ 的所有 $x$ 在搜索时均排在 $v$ 之后。此处计算方案数时,可设 $u$ 共有 $S$ 个儿子,则先拿出 $mn_x\le i$ 的 $c$ 个(其中必然包含 $v$),要求 $v$ 在开头,其他随意排列,剩余的 $(S-c)$ 个也随意排列。之后将 $S$ 个位置分类并分别放入,即为合法排列数。因此概率为 $$ \frac{{S\choose c}\times (c-1)!\times (S-c)!}{S!}=\frac{S!}{c!\times (S-c)!}\times (c-1)!\times (S-c)!\times \frac 1{S!}=\frac 1 c. $$ 转移即为 $f_{u,i}\leftarrow f_{v,i}\times\frac 1 {\sum_{x\in son(u)}[mn_x\le i]}$。
另一种情况是最小值变成了另一子树 $x$ 的最小值 $mn_x$,最终结束于 $v$ 子树内的终点,这里显然有限制 $mn_x\le i$。此时先将 $u$ 的所有子树按 $mn$ 从小到大排序,设 $rk_x$ 为排序后 $x$ 的排名。则所有排在 $x$ 前面的子树中,除 $v$ 外的其他子树必须排在 $v$ 之后。
那么需要分 $v$ 在 $x$ 之前和之后讨论,以确定需要在 $v$ 之后的子树个数。这里以 $v$ 在 $x$ 之前为例,仿照上面可以把 $rk_x$ 个子树拿出来单独排列,并将 $x,v$ 分别放到前两位,概率即为 $$ \frac{{S\choose rk_x}\times (rk_x-2)!\times (S-rk_x)!}{S!}=\frac{S!}{rk_x!\times (S-rk_x)!}\times (rk_x-2)!\times (S-rk_x)!\times \frac 1{S!}=\frac 1 {rk_x(rk_x-1)}. $$ 同理可得 $v$ 在 $x$ 之后时系数为 $\frac 1{rk_x(rk_x+1)}$。转移时为降低复杂度,可以枚举 $x$,从而用前缀和优化省去对 $v,i$ 的枚举,即 $$ f_{u,mn_x}\leftarrow\frac 1 {rk_x(rk_x-1)}\sum_{rk_v<rk_x}\sum_{i\ge mn_x} f_{v,i}+\frac 1{rk_x(rk_x+1)}\sum_{rk_v>rk_x}\sum_{i\ge mn_x} f_{v,i}. $$
这两种转移完成后,由于从 $u$ 出发必然经过 $a_u$,需要把所有 $f_{u,i}$ 的 $i$ 对 $a_u$ 取 min,即将大于 $a_u$ 的 DP 值均加到 $f_{u,a_u}$ 上并清空。最后给答案加上所有 $w(u,x)$ 的贡献,即 $\sum i\times f_{u,i}$。
另外注意到会有相等的 $a$ 值,这时钦定 $f_{v,i}$ 中的 $i$ 为等大的数中最大的,其余相等的 $mn_x$ 中 $rk$ 在前的较小,这样定义后整个 DP 过程即上述,可以实现不重不漏。目前时间复杂度为 $O(n^2)$。
考虑优化 DP 过程,注意到第一种转移是对应位置累加,且每个位置上转移系数相等,可以使用线段树合并。同时由于转移系数只与不超过 $i$ 的 $mn_x$ 个数有关,不同的区间只有 $O(deg_u)$ 个,可以用区间乘解决。
进行第二种转移时,可以开一棵临时的线段树,通过合并得到 $rk$ 数组上前后缀的线段树,再进行 $i\ge mn_x$ 的区间查询,最后进行单点加即可。注意两者的系数不同,需要分别顺序和逆序做。最后对 $a_u$ 取 min 只需区间查询,清空即为区间乘 $0$,也是区间乘。
所以需要实现线段树合并,并支持区间乘,单点加,区间求和,这些操作均不难实现,时空复杂度 $O(n\log n)$。由于有区间乘和临时的前后缀线段树,最终空间大概需要 $4$ 倍的 $n\log n$,$3\times 10^7$ 就足够了。
附上代码:
#include<iostream>
#include<vector>
#include<algorithm>
#define pb push_back
#define mid ((l+r)>>1)
using namespace std;
const int N=4e5+10;
const int P=4e5;
const int M=3e7+10;
const int mod=998244353;
void add(int &a,int b) {a+=b;if(a>=mod)a-=mod;}
int n,m,rot,res,a[N],b[N],x[N],mn[N],rt[N],inv[N];
vector <int> e[N];
bool cmp(int i,int j) {return mn[i]<mn[j];}
struct sgmtt
{
int t,lc[M],rc[M],w[M],k[M],tag[M];
void cle(int u) {lc[u]=rc[u]=w[u]=k[u]=0,tag[u]=1;}
void pushup(int u) {w[u]=w[lc[u]],k[u]=k[lc[u]],add(w[u],w[rc[u]]),add(k[u],k[rc[u]]);}
void pt(int u,int x) {w[u]=1ll*w[u]*x%mod,k[u]=1ll*k[u]*x%mod,tag[u]=1ll*tag[u]*x%mod;}
void pushdown(int u) {if(tag[u]!=1) pt(lc[u],tag[u]),pt(rc[u],tag[u]),tag[u]=1;}
void update(int u,int l,int r,int L,int R,int x)
{
if(!u||L>R) return;
if(l>=L&&r<=R) {pt(u,x); return;}
pushdown(u);
if(L<=mid) update(lc[u],l,mid,L,R,x);
if(R>mid) update(rc[u],mid+1,r,L,R,x);
pushup(u);
}
void change(int &u,int l,int r,int p,int x)
{
if(!u) u=++t,cle(t);
if(l==r) {add(k[u],x),w[u]=1ll*k[u]*b[l]%mod; return;}
pushdown(u);
if(p<=mid) change(lc[u],l,mid,p,x);
else change(rc[u],mid+1,r,p,x);
pushup(u);
}
int query(int u,int l,int r,int L,int R)
{
if(!u||L>R) return 0;
if(l>=L&&r<=R) return k[u];
pushdown(u); int tr=0;
if(L<=mid) add(tr,query(lc[u],l,mid,L,R));
if(R>mid) add(tr,query(rc[u],mid+1,r,L,R));
return tr;
}
int merg(int u,int v,int l,int r)
{
if(!u||!v) return u+v;
int p=++t; cle(t);
if(l==r) w[p]=w[u],k[p]=k[u],add(w[p],w[v]),add(k[p],k[v]);
else pushdown(u),pushdown(v),lc[p]=merg(lc[u],lc[v],l,mid),rc[p]=merg(rc[u],rc[v],mid+1,r);
if(l<r) pushup(p);
return p;
}
}T;
void dfs(int u,int fat)
{
mn[u]=a[u]; vector <int> p;
for(int v:e[u]) if(v!=fat)
{
dfs(v,u),p.pb(v);
mn[u]=min(mn[u],mn[v]);
}
sort(p.begin(),p.end(),cmp);
int s=p.size(),cur=0;
for(int i=0;i<s;i++)
{
x[i]=1ll*T.query(cur,1,m,mn[p[i]],m)*inv[i]%mod*inv[i+1]%mod;
cur=T.merg(cur,rt[p[i]],1,m);
}
cur=0;
for(int i=s-1;~i;i--)
{
add(x[i],1ll*T.query(cur,1,m,mn[p[i]],m)*inv[i+1]%mod*inv[i+2]%mod);
cur=T.merg(cur,rt[p[i]],1,m),rt[u]=T.merg(rt[u],rt[p[i]],1,m);
}
for(int i=1;i<s;i++) if(mn[p[i]]!=mn[p[i-1]]) T.update(rt[u],1,m,mn[p[i-1]],mn[p[i]]-1,inv[i]);
if(s) T.update(rt[u],1,m,mn[p[s-1]],m,inv[s]);
for(int i=0;i<s;i++) T.change(rt[u],1,m,mn[p[i]],x[i]);
T.change(rt[u],1,m,a[u],T.query(rt[u],1,m,a[u]+1,m)+1),T.update(rt[u],1,m,a[u]+1,m,0),add(res,T.w[rt[u]]);
}
void sol()
{
cin>>n>>rot,res=T.t=0;
for(int i=1;i<=n;i++) cin>>a[i],b[i]=a[i],rt[i]=0,e[i].clear();
sort(b+1,b+1+n),m=unique(b+1,b+1+n)-b-1;
for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+1+m,a[i])-b;
for(int i=1,u,v;i<n;i++) cin>>u>>v,e[u].pb(v),e[v].pb(u);
dfs(rot,0),cout<<res<<'\n';
}
int main()
{
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
inv[0]=inv[1]=1;
for(int i=2;i<=P;i++) inv[i]=1ll*(mod-mod\/i)*inv[mod%i]%mod;
int TT; cin>>TT;
while(TT--) sol();
return 0;
}

鲁ICP备2025150228号