Fork me on GitHub

「CF739E」 Gosha is hunting WQS二分优化dp

Links there:CF739E

写了个 $O(n^3)$的概率转移 $dp$ 被劝退了不会优化 只能去补姿势

题意

有 $N$ 个怪物,你有 $a$ 个 $Pokeball$ 和 $b$ 个 $Ultraball$.

给出每个怪物分别被 $Pokeball$ 和被 $Ultraball$ 捕获的概率 求最大能捕获个数的期望.

思路

上来抱着 $codeforces $ 跑的巨快的心态写了一发 $O(n^3)$

用 $f(i,j,k)$ 表示 前 $i$ 个怪物用 $j$ 个 $Pokeball$, $k$ 个 $Ultraball$ 的最大期望.

这个转移显然,讨论一下每个怪物

不用球 / 用Pokeball / 用Ultraball / 都用

大概写出来这样子

1
2
3
4
5
6
7
8
9
10
11
12
for (int i = 1; i <= n; ++i) {
for (int j = 0; j <= a; ++j)
for (int k = 0; k <= b; ++k){
int lst = (i & 1) ? 0 : 1;
upd(f[lst^1][j][k],f[lst][j][k]);
if (j >= 1) upd(f[lst^1][j][k],f[lst][j-1][k] + p[i]);
if (k >= 1) upd(f[lst^1][j][k],f[lst][j][k-1] + u[i]);
if (j >= 1 && k >= 1)
upd(f[lst^1][j][k],f[lst][j-1][k-1] + 1.0 - (1.0 - p[i]) * (1.0 - u[i]));
if (i == n) upd(ans,f[lst^1][j][k]);
}
}

然后你就发现你 $T$ 飞了.

make

用 $WQS$ 二分的思想

原问题显然必须把 $a,b$ 取光.

所以这实际上是一个有2个限制的取 $K$ 个物品的问题

固定 $f(i,a,k)$ 的 $i,a$ 的时候 , 发现 $f(i,a,k)$ 是关于 $k$ 的凸函数.

这是显然的对于固定的 $i$ ,你扔的球的个数越多的话收益越差.

类似的 , 固定 $f(i,j,b)$ 的 $i,b$ 的时候 , 发现 $f(i,j,b)$ 是关于 $j$ 的凸函数.

那么不如给每个球加两个限制 $cost1 \space cost2$ 表示每个类型的球额外的代价.

二分套二分,最后把这一部分影响还原就是答案.

因为 $cost1 \space cost2$ 的范围在 $[0,1]$ 之间 所以二分的次数很少 (看你 eps)

几乎就是 $O(n\times x^2)$ ,$x$ 为二分常数 可小了.

跑的还真是快呢.

make2

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
//Keep pluggin',this is your only outlet.
#include<bits/stdc++.h>
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#define MM(x,y) memset(x,y,sizeof(x))
#define MCPY(a,b) memcpy(a,b,sizeof(b))
#define pb push_back
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,b,a) for(int i=b;i>=a;i--)
#define fi first
#define mp make_pair
#define se second
using namespace std;
#define int long long

inline int quickpow(int m,int n,int p){int b=1;while(n){if(n&1)b=b*m%p;n=n>>1;m=m*m%p;}return b;}
inline int quickmul(int x,int y,int mod){return (x*y-(int)((long double)x/mod*y)*mod+mod)%mod;}
inline int getinv(int x,int p){return quickpow(x,p-2,p);}
inline int read(void){
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){f=ch=='-'?-1:1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return x * f;
}
double f[2010];
double p[2010],u[2010],pu[2010];
int numa[2010],numb[2010];
int a,b,n;
inline void upd(double &x,double y) {x = max(x,y);}
const double eps = 1e-10;
inline void Calc(double c1,double c2){
f[0] = 0.0; numa[0] = numb[0] = 0;
for (int i = 1; i <= n; ++i) {
f[i] = f[i-1] ; numb[i] = numb[i-1]; numa[i] = numa[i-1];
if (f[i-1] + p[i] - c1 > f[i]) {
f[i] = f[i-1] + p[i] - c1;
numa[i] = numa[i-1] + 1;
numb[i] = numb[i-1];
}
if (f[i-1] + u[i] - c2 > f[i]) {
f[i] = f[i-1] + u[i] - c2;
numa[i] = numa[i-1];
numb[i] = numb[i-1] + 1;
}
if (f[i-1] + pu[i] - (c1 + c2) > f[i]) {
f[i] = f[i-1] + pu[i] - (c1 + c2);
numa[i] = numa[i-1] + 1;
numb[i] = numb[i-1] + 1;
}
}
}

///------------------head------------------
signed main(signed argc, char *argv[]){
n = read(),a = read(),b = read();
rep(i,1,n) scanf("%lf",&p[i]);
rep(i,1,n) scanf("%lf",&u[i]);
rep(i,1,n) pu[i] = 1.0 - (1.0 - p[i]) * (1.0 - u[i]);
//O(n^3) dp
// double ans = 1e-10;
// for (int i = 1; i <= n; ++i) {
// for (int j = 0; j <= min(i,a); ++j)
// for (int k = 0; k <= b; ++k){
// int lst = (i & 1) ? 0 : 1;
// upd(f[lst^1][j][k],f[lst][j][k]);
// if (j >= 1) upd(f[lst^1][j][k],f[lst][j-1][k] + p[i]);
// if (k >= 1) upd(f[lst^1][j][k],f[lst][j][k-1] + u[i]);
// if (j >= 1 && k >= 1)
// upd(f[lst^1][j][k],f[lst][j-1][k-1] + 1.0 - (1.0 - p[i]) * (1.0 - u[i]));
// if (i == n) upd(ans,f[lst^1][j][k]);
// }
// }
double L = 0,R = 1,L1,R1;
while(R - L > eps) {
double mid1 = (L + R) / 2;
L1 = 0,R1 = 1;
while(R1 - L1 > eps){
double mid2 = (L1 + R1) / 2;
Calc(mid1,mid2);
if (numb[n] > b) L1 = mid2;
else R1 = mid2;
}
Calc(mid1,R1);
if (numa[n] > a) L = mid1;
else R = mid1;
}
Calc(R,R1);
double ans = f[n] + a * R + b * R1;
printf("%.10lf\n",ans);
return 0;
}

/* Examples: */
/*

*/

/*

*/

发现自己一万年没更blog了.之前的题会陆陆续续补上来吧.