Fork me on GitHub

FWT-FMT子集变换卷积学习笔记

终于被论文题劝退之后来补这个坑了

引言

被HAOI2015按位或劝退

滚回去补vfk的论文

so happy

参见2015国家队论文《集合幂级数的性质与应用及其快速算法》

集合幂级数

引入

之前对于有关集合的计数问题 ,一般的常规思路是用 $f_S$ 去表示方案数然后去递推.

然而这个时候你看一眼题面 意识到自己 $O(4^n)$ 甚至连暴力都不如 (不对 这就是暴力)

于是我们需要一个东西来优化递推转移

还是用数列的理论来类比.之前多项式的一些理论是不是在集合中也可以使用呢?

我们曾经对于一个数列 $f0,f_1…$ 用一个生成函数来描述 即 $f(x) = \sum { k=0 } ^ { inf } f_k\times x^k$

左边是严格的数列递推 右边是优美的多项式 有成套的工具(板子)去解决.

但是集合的处理怎么办?

引入集合幂级数的概念

定义

令 $f​$ 是定义域在集合 $U​$ 以内,映射到 $F​$ 的一个集合幂级数 , 对于每一个定义域中的 $S,S \in U​$,设 $f_S​$ 为该 $S​$ 带入函数中所得到的值.

结论 确定了每一个 $S$ 所对应的 $f_S$ , $f$ 也可以随之确定且唯一.

我们用类似生成函数的定义方式的式子来定义 $f$ . $f = \sum _ { S \in U } f_S \times x ^ { S } $

这里的 $ f_S $ 并非系数 而是一个 $U$ 的子集所对应的一个点值.

我还是举论文上的例子吧。。。

$U={ 1,2 } , f(x) = 5x^{ \varnothing }$$+ 8x^ { { 1 } }$ $+$ $13x^ { { 1,2}}$

则这个集合幂级数里,$f({ 1,2 } ) = 13$,$f(\varnothing) = 5,f( { 1}) = 8$

考虑这个玩意的运算.

加减法显然系数相加相减即可.但是乘法呢?我们对于集合的运算取并取交 得到的似乎并不相同啊

引出子集变换集合卷积

FMT(Fast Mobius Transform)快速莫比乌斯变换

取集合运算 $L \subseteq U,R \subseteq U$,$L * R = L \cup R$

$f(x) = \sum { S \subseteq L} f_S x^S ,g(x) = \sum { S \subseteq R}g_S x^S$

它们的卷积为 $h(x) = \sum { L \subseteq U} \sum { R\subseteq U}f_Lg_R(L \cup R = S)$

Bruteforce

暴力枚举集合 复杂度 $O(3^N)$ 美滋滋

Divide & Conquer

考虑分治乘 对于现有的集合 $U = \ { 1,2,3,4,…n}$ 考虑 $ n $ 单独提出来 分类.

现在记号 $F_1(x)$ 表示现有的集合中不包含 $n$ 的 $F_2(x)$ 表示现有的集合中包含 $n$ 但是要除去的

$f * g = (F_1+ { x_nF_2})(G_1+x_nG_2)$

然后展开分治算就可以了.

$T(n) = T(n-1)+O(2^N)$ 由主定理 $T(n) = O(N\times2^N)$

据炸鸡小弟声称 分治乘往往跑的比$FMT$快 但也并不知道是为什么…

Fast Mobius Transform

证明看论文 这里只讲结论

定义一个幂级数的莫比乌斯变换为

$FS = \sum { T \subseteq S}f_T$

反过来 我们定义 $F$ 的莫比乌斯反演为 $f$

推个容斥错位相减 $fS = \sum { T \subseteq S} { (-1)^ { |S-T|}F_T}$

我们算卷积的时候 $h(x) = \sum { L \subseteq U} \sum { R\subseteq U}f_Lg_R(L \cup R = S)$

两边同时反演.最后可以得到 $H = F * G$ 即 反演后依然成立

最后两边反演即可 . 这样我们算发 $ f,g$ 的卷积 只要先莫比乌斯变换 然后相乘之后 再莫比乌斯反演 就可以啦

复杂度$O(N \times 2^N)$

模板 StupidOJ47

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
//my vegetable has exploded. :(
#include<bits/stdc++.h>
#define max(x,y) (x>y?x:y)
#define min(x,y) (x<y?x:y)
#define MM(x,y) memset(x,y,sizeof(x))
#define MCPY(a,b) memcpy(a,b,sizeof(b))
#define pb push_back
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define per(i,a,b) for(int i=b;i>=a;i--)
#define fi first
#define se second
using namespace std;
#define int long long

inline int quickpow(int m,int n,int p) { int b=1;while(n) { if(n&1)b=b*m%p;n=n>>1;m=m*m%p;}return b;}
inline int getinv(int x,int p) { return quickpow(x,p-2,p);}
inline int read(void) {
int x=0,f=1;char ch=getchar();
while(!isdigit(ch)) { f=ch=='-'?-1:1;ch=getchar();}
while(isdigit(ch)) { x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return x * f;
}
const int MAXN = 2e6+100;
int n,a[MAXN],b[MAXN],l = 0;
void fmt(int a[],int len,int type) { //type = -1 反演
for (int i = 0; i < len; ++i)
for (int j = 0; j < (1 << len) - 1; ++j)
a[j] += type * a[(1<<i)^j] * (j >> i & 1);
}
///------------------head------------------
signed main(signed argc, char *argv[]) {
n = read();
for (int i = 1; i <= n; i = i << 1, ++l);
rep(i,0,n-1) a[i] = read();
rep(i,0,n-1) b[i] = read();
fmt(a,l,1); fmt(b,l,1);
rep(i,0,n-1) a[i] = a[i] * b[i];
fmt(a,l,-1);
rep(i,0,n-1) printf("%lld ",a[i]);
return 0;
}

/* Examples: */
/*

*/

/*

*/

Fast Walsh-Hadamard Transform

本质上就是每一位只可能是 $1,0$ 的多项式

懒得证明正确性了 但是这是确实存在的

$FWT(A+B) = FWT(A) +FWT(B)$

$FWT(A \oplus B) = FWT(A) \times FWT(B)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
void fwt_and(int *a,int type) { 
for (int i = 2; i <= n; i <<= 1)
for (int p = (i >> 1) , j = 0; j < n; j += i)
for (int k = j; k < j + p; ++k)
a[k+p] += a[k] * type;
}
void fwt_or(int *a,int type) {
for (int i = 2; i <= n; i <<= 1)
for (int p = (i >> 1), j = 0; j < n; j += i)
for (int k = j; k < j + p; ++k)
a[k] += a[k+p] * type;
}
void fwt_xor(int *a,int type) { //mod意义下
for (int i = 2; i <= n; i <<= 1)
for (int p = (i >> 1), j = 0; j < n; j += i)
for (int k = j; k < j + p; ++k) {
int x = (a[k] + a[k+p]),y = (a[k]-a[k+p]);
a[k] = x; a[k+p] = y;
}
int inv = getinv(quickpow(2,n,Mod),Mod);
if (type == -1) for (int i = 0; i < n; ++i) a[i] = a[i] * inv % Mod;
}