Fork me on GitHub

「BZOJ4709 / JSOI2011」 柠檬 决策单调优化DP

发现自己对 $DP$ 优化一无所知.

Links there:BZOJ-4709

题意

有 $N$ 个贝壳,每个贝壳有颜色 $s_i$ 可以取若干段的贝壳并每次指定一个颜色 $x$ ,定义在每个选择的区间内对答案的贡献为 $x \times t[x]^2$ ,其中 $t[x]$ 为在这一段中颜色 $x$ 的出现次数

思路

先要想一个结论:

假设每次取的区间为 $[L,R]$ 那么必定有 $s_L = s_R$ ,也就是两端颜色相同.这样取一定是最优的.

如果两端颜色不一样的话 那么这一点可以归到左边或者归到右边 而且对单单这一段的答案没有影响 所以正确性显然 这样找的一定最大

那么记录 $c[i]$ 为颜色 $s_i$ 从开始出现到现在的次数 显然有转移

这样的复杂度是 $O(n^2)$ 的.

考虑优化 显然不能斜率优化 因为对于 $i<j$ , $i$ 的决策会影响到 $j$ 的决策

观察转移方程 发现如果说有 $k < j \leq i$ 假如 $k$ 的转移要比 $j$ 优秀,那么在后面的转移中 $j$ 是一定用不到的 因为平方会越来越大 因此可以对每一个颜色开一个栈 每次发现栈顶的没有下面那个转移优秀就把他弹掉.

然后你就发现你WA了

因为这么做只能保证被弹掉的不再有用 但我们求的是最大值 无法保证栈顶的一定最优秀 比如出现栈中第三个更优秀的情况.

那么我们就要从入栈时间来想 仍然假设 $k < j \leq i$ ,我们假定 $k$ 转移在 $k_1$ 优于 $i$ ,$j$ 在 $j_1$ 优于

$i$ 如果有 $k_1 < j_1$ 会发生什么呢.

仔细模拟之后我们发现 栈中第一个元素比第二个优秀,而下面还有一个更优秀的 $k_1$ 我们没把它取出来.

这就是万恶之源 因此遇到 $k_1 < j_1$ 的情况 就弹出 保证栈内元素比上一个元素更优的时间也是单调的 因此二分去找这个 “更优时间” 就行啦.

复杂度$O(nlogn)$ 可以搞过去.

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
//Keep pluggin',this is your only outlet.
#include<bits/stdc++.h>
#define MM(x,y) memset(x,y,sizeof(x))
#define MCPY(a,b) memcpy(a,b,sizeof(b))
#define pb push_back
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,b,a) for(int i=b;i>=a;i--)
#define fi first
#define mp make_pair
#define se second
using namespace std;
#define int long long

inline int quickpow(int m,int n,int p){int b=1;while(n){if(n&1)b=b*m%p;n=n>>1;m=m*m%p;}return b;}
inline int quickmul(int x,int y,int mod){return (x*y-(int)((long double)x/mod*y)*mod+mod)%mod;}
inline int getinv(int x,int p){return quickpow(x,p-2,p);}
inline int read(void){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){f=ch=='-'?-1:1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return x * f;
}

const int MAXN = 1e5 + 100;
vector<int>sta[MAXN];
int n,col[MAXN],a[MAXN],f[MAXN],c[MAXN];

inline int Calc(int x,int y) {
return f[x-1] + (c[y]-c[x]+1) * (c[y]-c[x]+1) * a[y];
}

inline int Tim(int x,int y) {
int l = 1,r = n,mid,cur = (c[y] - c[x]);
while(l < r) {
mid = (l + r) >> 1;
if (f[x-1] + (mid + cur) * (mid + cur) * a[x] < f[y-1] + mid * mid * a[y])
l = mid + 1;
else r = mid;
}
return r;
}


///------------------head------------------
signed main(signed argc, char *argv[]){
n = read();
rep(i,1,n) a[i] = read();
for (int i = 1; i <= n; ++i) {
c[i] = ++col[a[i]];
while((signed)sta[a[i]].size()>=2 && Tim(sta[a[i]][(signed)sta[a[i]].size()-2],i) < Tim(sta[a[i]][(signed)sta[a[i]].size()-1],i))
sta[a[i]].pop_back();
sta[a[i]].push_back(i);
while((signed)sta[a[i]].size()>=2 && Calc(sta[a[i]][(signed)sta[a[i]].size()-2],i) > Calc(sta[a[i]][(signed)sta[a[i]].size()-1],i))
sta[a[i]].pop_back();
f[i] = Calc(sta[a[i]][sta[a[i]].size()-1],i);
for (int j = 0; j < sta[a[i]].size()-1; ++j)
f[i] = max(f[i],Calc(sta[a[i]][j],i));
//printf("%lld\n",f[i]);
}
printf("%lld\n",f[n]);
return 0;
}

/* Examples: */
/*

*/

/*

*/