本文章由 WyOJ Shojo 从洛谷专栏拉取,原发布时间为 2025-03-17 14:54:16
可以看看这个 CF833B。
题意
给你一个序列,你要把这个序列分成连续的三段,每一段的权值为每一段中不同的数的个数,问你这个序列的最大权值是多少。
思路
首先发现可以预处理来第一段和最后一段的权值,设为 $ans1$ 和 $ans2$,$ans1_i$ 即为从 $1$ 到 $i$ 这一段的权值,$ans2_i$ 即为从 $i$ 到 $n$ 这一段的权值,于是我们就有了 $O(n^2)$ 的 dp 做法。
设 $f_i$ 表示以 $i$ 为第二段结尾的最大权值,答案即为 $\max f_i$。
$$f_i=\max_{j=1}^{i-1} ans1_j+val(j+1,i)+ans2_j$$
$val(i,j)$ 即为从 $i$ 到 $j$ 这一段的权值。
考虑优化。
我们发现,难点在于如何计算 $val(j+1,i)$。
考虑说我们可以把每一个 $ans1$ 扔到线段树上,一个点 $j$ 就表示说以 $j$ 为第一段的结尾最大权值,然后我们枚举第二段的结尾 $i$,每次 dp。
我们记 $a_i$ 上一次出现的位置为 $la_i$。
考虑说我们每个数只对于起点在 $[la_i+1,i]$ 这个闭区间的所有区间有 $1$ 的贡献,也就是线段树上 $[la_i,i-1]$ 这个区间。
最后一段的直接加上即可。
然后每次转移找最大值。
复杂度 $O(n\log n)$。
不理解可以结合代码食用。
Code
#include <bits\/stdc++.h>
#define endl '\n'
#define int long long
#define fi first
#define se second
using namespace std;
const int N=3e5+10;
const int inf=0x3f3f3f3f3f3f3f3f;
int n;
int a[N];
int ans1[N],ans2[N];
unordered_map<int,int> mp;
int ans;
int pre[N],nxt[N];
struct Node
{
int l,r,w;
int lt;
}tr[N<<2];
void build(int rt,int l,int r)
{
tr[rt].l=l,tr[rt].r=r;
if(l==r)
{
tr[rt].w=ans1[l];
return ;
}
int mid=(tr[rt].l+tr[rt].r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
tr[rt].w=max(tr[rt<<1].w,tr[rt<<1|1].w);
}
void pushdown(int rt)
{
int &tag=tr[rt].lt;
tr[rt<<1].w+=tag;
tr[rt<<1|1].w+=tag;
tr[rt<<1].lt+=tag;
tr[rt<<1|1].lt+=tag;
tag=0;
}
void add(int rt,int l,int r,int k)
{
if(tr[rt].l>=l&&tr[rt].r<=r)
{
tr[rt].w+=k;
tr[rt].lt+=k;
return ;
}
pushdown(rt);
int mid=(tr[rt].r+tr[rt].l)>>1;
if(l<=mid) add(rt<<1,l,r,k);
if(r>mid) add(rt<<1|1,l,r,k);
tr[rt].w=max(tr[rt<<1].w,tr[rt<<1|1].w);
}
int check(int rt,int l,int r)
{
if(tr[rt].l>=l&&tr[rt].r<=r) return tr[rt].w;
pushdown(rt);
int mid=(tr[rt].l+tr[rt].r)>>1;
int res=0;
if(l<=mid) res=max(res,check(rt<<1,l,r));
if(r>mid) res=max(res,check(rt<<1|1,l,r));
return res;
}
signed main()
{
\/\/freopen(".in","r",stdin);
\/\/freopen(".out","w",stdout);
cin.tie(0);
cout.tie(0);
ios::sync_with_stdio(false);
cin>>n;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1;i<=n;i++)
{
pre[i]=mp[a[i]];
mp[a[i]]=i;
ans1[i]=mp.size();
}
mp.clear();
for(int i=n;i>=1;i--)
{
nxt[i]=mp[a[i]];
mp[a[i]]=i;
ans2[i]=mp.size();
}
build(1,1,n);
for(int i=2;i<n;i++)
{
add(1,(pre[i]?pre[i]:1),i-1,1);
ans=max(check(1,1,i-1)+ans2[i+1],ans);
}
cout<<ans;
return 0;
}

鲁ICP备2025150228号