数学相关
匹配相关

FFT

HJWJBSR posted @ 2015年6月17日 17:16 in 专题 , 424 阅读

其实早在年初就学过几次FFT了,算导也翻过两遍,就是从没写过。

昨天做题目异常烦躁,于是滚去背了下FFT模板(本来是准备早自习抽个时间背的,但是发现这个还是蛮好背的,毕竟带花树都背下来了= =),然后水了两道题。

UOJ#34

贴个模板= =,其实我是从耗时少的里面随便找了一个@kiana的背的,虽然还没看懂这个非递归版本是怎么搞的(其实是不想去认真看了= =),反正感觉也够用了。

#include <iostream>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 300050
// pi to full double precision.
const double pi = acos(-1.0);

// Minimal complex number used by FFT: real part r, imaginary part i.
struct complex
{
    double r, i;
    complex() {}
    complex(double _r, double _i) : r(_r), i(_i) {}
};
// a, b: input polynomials (overwritten by their transforms); c: their product.
complex a[N], b[N], c[N];

complex operator+(const complex &A, const complex &B)
{
    return complex(A.r + B.r, A.i + B.i);
}

complex operator-(const complex &A, const complex &B)
{
    return complex(A.r - B.r, A.i - B.i);
}

complex operator*(const complex &A, const complex &B)
{
    return complex(A.r * B.r - A.i * B.i, A.r * B.i + A.i * B.r);
}

int n, m, k;   // coefficient counts of the two inputs and of the product
int len, bit;  // transform length (power of two) and log2(len)
int rev[N];    // bit-reversal permutation for len
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it. The character that terminates the
// number is consumed.
// Returns 0 if end of input is reached before any digit. getchar()'s result
// is kept in an int so EOF stays distinguishable; the previous version
// truncated it to char and looped forever at end of input.
inline int Read()
 {
 	int x=0,y;
 	do y=getchar(); while (y!=EOF&&(y<'0'||y>'9'));
 	if (y==EOF) return 0;
 	do x=x*10+y-'0',y=getchar(); while (y>='0'&&y<='9');
 	return x;
 }
// Iterative radix-2 Cooley-Tukey FFT, in place on a[0..n).
// Preconditions: n is a power of two and the global rev[] holds the
// bit-reversal permutation for n (main builds it).
// f == 1 computes the forward transform. Any other value (main passes -1)
// computes the inverse; afterwards only the real parts are divided by n,
// and the imaginary parts are not rescaled.
void FFT(complex* a,int n,int f)
{
    // Move every element to its bit-reversed slot so the passes can run in place.
    for (int idx = 0; idx < n; ++idx)
    {
        const int mirror = rev[idx];
        if (idx < mirror) std::swap(a[idx], a[mirror]);
    }
    // Each pass merges pairs of length-`half` transforms into length-2*half ones.
    int half = 1;
    while (half < n)
    {
        const double angle = pi / half;
        // Root of unity for this pass (conjugated when f == -1).
        const complex step(cos(angle), sin(angle) * f);
        const int span = half << 1;
        for (int base = 0; base < n; base += span)
        {
            complex w(1, 0);
            for (int k = 0; k < half; ++k)
            {
                const complex u = a[base + k];
                const complex t = w * a[base + k + half];
                a[base + k] = u + t;
                a[base + k + half] = u - t;
                w = w * step;
            }
        }
        half = span;
    }
    if (f == 1) return;
    for (int idx = 0; idx < n; ++idx) a[idx].r /= n;
}
// Reads the degrees of two polynomials followed by their coefficients,
// multiplies them with FFT, and prints the k = n+m-1 product coefficients
// separated by spaces.
int main()
{
    n = Read() + 1;  // coefficient count of the first polynomial
    m = Read() + 1;  // coefficient count of the second polynomial
    k = n + m - 1;   // coefficient count of the product
    // Smallest power of two >= k, together with its log2.
    for (len = 1, bit = 0; len < k; len <<= 1) ++bit;
    for (int t = 0; t < n; ++t) a[t].r = Read();
    for (int t = 0; t < m; ++t) b[t].r = Read();
    // Bit-reversal permutation, derived incrementally from rev[t/2].
    for (int t = 1; t < len; ++t)
        rev[t] = (rev[t >> 1] >> 1) | ((t & 1) << (bit - 1));
    FFT(a, len, 1);
    FFT(b, len, 1);
    for (int t = 0; t < len; ++t) c[t] = a[t] * b[t];
    FFT(c, len, -1);
    // Round to the nearest integer (coefficients are non-negative).
    for (int t = 0; t < k; ++t) printf("%d ", static_cast<int>(c[t].r + 0.5));
    printf("\n");
    return 0;
}

BZOJ3160万径人踪灭:

其实就是统计回文子序列个数减去回文子串个数(这里的回文子序列要求位置也关于某个中心对称)。后者直接上Manacher;前者对每种字符分别做一次FFT自卷积,得到每个对称中心两侧相同字符的对数g,该中心贡献2^g-1个。

#include <iostream>
#include <cstring>
#include <cmath>
#include <cstdio>
#include <algorithm>
using namespace std;
#define N 300050
#define M 1000000007   // answer modulus
#define eps 1e-2       // rounding slack when converting FFT output to integers
#define ll long long
const double pi = acos(-1.0);

// Minimal complex number used by TAT: real part r, imaginary part i.
struct complex
{
    double r, i;
    complex() {}
    complex(double _r, double _i) : r(_r), i(_i) {}
};
// a, b: FFT operands (overwritten by their transforms); c: their product.
complex a[N], b[N], c[N];

complex operator+(const complex &A, const complex &B)
{
    return complex(A.r + B.r, A.i + B.i);
}

complex operator-(const complex &A, const complex &B)
{
    return complex(A.r - B.r, A.i - B.i);
}

complex operator*(const complex &A, const complex &B)
{
    return complex(A.r * B.r - A.i * B.i, A.r * B.i + A.i * B.r);
}

int n, m;      // n: string length (m is not used by this program)
int len, bit;  // transform length and log2(len), set by QAQ
int rev[N];    // bit-reversal permutation for len
int ss;        // working length: 2n for Manacher, 2n-1 for QAQ
int d[N];      // input string interleaved with 0 separators (for Manacher)
int ft[N];     // Manacher radii
ll ans;        // running answer modulo M
// In-place iterative radix-2 FFT on a[0..n).
// n must be a power of two, and rev[] must already hold the bit-reversal
// permutation for n (QAQ fills it).
// f == 1 gives the forward transform. Otherwise the inverse is computed,
// and only the real parts are scaled by 1/n.
void TAT(complex* a,int n,int f)
{
    for (int u = 0; u < n; ++u)
        if (u < rev[u])
            std::swap(a[u], a[rev[u]]);

    for (int p = 1; p < n; p *= 2)
    {
        const complex wn(cos(pi / p), sin(pi / p) * f);
        for (int start = 0; start < n; start += 2 * p)
        {
            complex *lo = a + start;  // first half of the current block
            complex *hi = lo + p;     // second half of the current block
            complex e(1, 0);
            for (int j = 0; j < p; ++j)
            {
                const complex x = lo[j];
                const complex y = e * hi[j];
                lo[j] = x + y;
                hi[j] = x - y;
                e = e * wn;
            }
        }
    }

    if (f != 1)
        for (int u = 0; u < n; ++u)
            a[u].r /= n;
}
// Multiplies the polynomials in a[] and b[] into c[].
// a[] and b[] are overwritten by their transforms.
// The transform length len is the smallest power of two >= ss, and bit is
// its log2; rev[] is rebuilt for that length.
void QAQ()
{
    for (len = 1, bit = 0; len < ss; len <<= 1) ++bit;
    for (int t = 1; t < len; ++t)
        rev[t] = (rev[t >> 1] >> 1) | ((t & 1) << (bit - 1));
    TAT(a, len, 1);
    TAT(b, len, 1);
    for (int t = 0; t < len; ++t) c[t] = a[t] * b[t];
    TAT(c, len, -1);
}
// Manacher's algorithm over d[0..ss], where the string sits at odd indices
// and 0 separators sit at even ones.
// ft[pos] receives the palindrome radius at pos, counting the centre itself.
// Returns the total of ft[pos]/2 over all positions (the number of
// palindromic substrings), reduced modulo M.
ll Manacher()
{
    memset(ft, 0, sizeof(ft));
    int center = 0;  // centre of the palindrome reaching furthest right
    int right = -1;  // right-most index covered by that palindrome
    ll total = 0;
    for (int pos = 0; pos <= ss; ++pos)
    {
        int &rad = ft[pos];
        // Reuse the mirrored radius when pos lies inside the known palindrome.
        rad = (pos > right) ? 1 : std::min(ft[2 * center - pos], right - pos + 1);
        while (pos - rad >= 0 && pos + rad <= ss && d[pos - rad] == d[pos + rad])
            ++rad;
        if (pos + rad - 1 > right)
        {
            right = pos + rad - 1;
            center = pos;
        }
        total += rad / 2;
    }
    return total % M;
}
// Binary exponentiation: returns x^y mod M.
inline ll Quick_mi(ll x,int y)
{
    ll prod = 1;
    for (; y; y >>= 1, x = x * x % M)
        if (y & 1)
            prod = prod * x % M;
    return prod;
}
int main()
 {
 	int i,j,k,l,q,w,e,g[N];char f[N];
 	memset(d,0,sizeof(d));memset(rev,0,sizeof(rev));
 	scanf("%s",f);n=strlen(f);
 	for (i=0;i<n;i++)
 	  d[i*2+1]=int(f[i]);
 	ss=n*2;ans=(M-Manacher())%M;ss--;
 	for (i=0;i<n;i++)
 	  a[i].r=b[i].r=f[i]=='a';
 	QAQ();
 	for (i=0;i<n;i++)
 	  a[i].r=b[i].r=f[i]!='a',a[i].i=b[i].i=0;
 	for (i=n;i<len;i++)
 	  a[i].r=b[i].r=a[i].i=b[i].i=0;
 	for (i=0;i<ss;i++)
 	  g[i]=(eps+c[i].r+(!(i&1)&&f[i/2]=='a'))/2;
 	QAQ();
 	for (i=0;i<ss;i++)
 	  g[i]+=(eps+c[i].r+(!(i&1)&&f[i/2]=='b'))/2;
 	for (i=0;i<ss;i++)
 	  ans=(ans+Quick_mi(2,g[i])-1)%M;
 	cout <<ans<<endl;
 	return 0;
 }

BZOJ3527力:

坑比lydsy没题面,我是到这里看的题。然后发现这个是可以化成卷积形式的,然后就直接FFT了= =

#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
#define N 300050
#define ld long double
#define eps (double)(1e-8)
const double pi = acos(-1.0);

// Minimal complex number used by FFT: real part r, imaginary part i.
struct complex
{
    double r, i;
    complex() {}
    complex(ld _r, double _i) : r(_r), i(_i) {}
};
// a, b: FFT operands; c and f: the two convolutions main prints the difference of.
complex a[N], b[N], c[N], f[N];

complex operator+ (const complex &A, const complex &B)
{
    return complex(A.r + B.r, A.i + B.i);
}

complex operator- (const complex &A, const complex &B)
{
    return complex(A.r - B.r, A.i - B.i);
}

complex operator* (const complex &A, const complex &B)
{
    return complex(A.r * B.r - A.i * B.i, A.r * B.i + A.i * B.r);
}

double d[N];   // input values
int n, m;      // n: number of inputs (m is not used by this program)
int len, bit;  // transform length and log2(len)
int rev[N];    // bit-reversal permutation for len
// Iterative radix-2 FFT, in place on a[0..n).
// n must be a power of two, and the global rev[] must already hold the
// bit-reversal permutation for n.
// f == 1 is the forward transform. Any other f (main passes -1) is the
// inverse; afterwards only the real parts are scaled by 1/n, and the
// imaginary parts are left unscaled.
void FFT(complex* a,int n,int f)
{
    // Bit-reversal reorder so the butterflies below can work in place.
    for (int i = 0; i < n; i++)
        if (i < rev[i]) std::swap(a[i], a[rev[i]]);
    // p: half-length of the blocks being merged in this pass.
    for (int p = 1; p < n; p <<= 1)
    {
        complex wn(cos(pi / p), sin(pi / p) * f);
        for (int i = 0; i < n; i += (p << 1))
        {
            complex e(1, 0);
            for (int j = 0; j < p; j++, e = e * wn)
            {
                complex x = a[i + j], y = a[i + j + p] * e;
                a[i + j] = x + y;
                a[i + j + p] = x - y;
            }
        }
    }
    if (f != 1)
        // Scale exactly the n transformed entries. The previous bound i<=n
        // also divided a[n], one element past the transform (out of bounds
        // if n ever reached the array size).
        for (int i = 0; i < n; i++)
            a[i].r /= n;
    return;
}
int main()
 {
 	  int i,j,k,l,q,w,e;
 	  memset(d,0,sizeof(d));memset(rev,0,sizeof(rev));
 	  cin >>n;bit=0;len=1;
 	  for (i=0;i<n;i++) scanf("%lf",&d[i]),
 	    a[i].r=d[i];
 	  for (i=1;i<n;i++) b[i].r=(double)1/i/i;
 	  while (len<n*2) len <<= 1,bit++;
    for (i=0;i<len;i++)
      rev[i]=(rev[i >> 1] >> 1)|((i&1) << (bit-1));
    FFT(a,len,1);FFT(b,len,1);
    for (i=0;i<len;i++)
      c[i]=a[i]*b[i],c[i].r+=eps;
    FFT(c,len,-1);
    for (i=0;i<len;i++)
      a[i].r=b[i].r=a[i].i=b[i].i=0;
    for (i=0;i<n;i++)
      a[i].r=d[n-i-1],a[i].i=0,b[i].i=0;
    for (i=1;i<n;i++) b[i].r=(double)1/i/i;
    FFT(a,len,1);FFT(b,len,1);
    for (i=0;i<len;i++)
      f[i]=a[i]*b[i];
    FFT(f,len,-1);
    for (i=0;i<n;i++)
      printf("%.3lf\n",c[i].r-f[n-i-1].r);
    return 0;
 }

话说最近的考试题目真是越来越sxbk了,鸿少的题目里面放了FFT也就算了,破太阳的题目里面放了多项式求逆。

等从杭州滚完粗回来就开推Picks的多项式专题

【捂脸贼】好像摆了半个月才开推啊

之前一直不知道有NTT这个东西,于是感觉要取模实在caodan

感觉NTT、多项式逆元、分治FFT也都不难理解的说。。

BZOJ3456城市规划

写的比较丑,反正就是NTT+高能逆元

好像确实时间常数炸翔了,不过在tsinsen上面还是过了

复杂度反正算出来是O(n log n)的

#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
#define N 300050      // capacity of every array (above 2^18 = 262144)
#define M 1004535809  // NTT modulus, 479 * 2^21 + 1
#define G 3           // base used to build the roots of unity in wn[]
#define ll long long

ll c[N];    // scratch: truncated copy of S inside Solve, final product in main
ll S[N];    // series being inverted / multiplied
ll T[N];    // inverse of S produced by Solve
ll wn[20];  // wn[k] = G^((M-1) >> k), root used by the k-th NTT pass
ll jc[N];   // jc[i] = 1 / i!  (mod M)
ll C[N];    // C[i] = 2^(i*(i-1)/2)  (mod M)

int n, m;     // n: input size (m is not used by this program)
int len, bit; // current transform length and log2(len)
int rev[N];   // bit-reversal permutation for len
// Binary exponentiation: returns x^y mod M.
ll Quick_Power(ll x,int y)
{
    ll acc = 1;
    for (ll base = x; y != 0; y >>= 1)
    {
        if (y & 1) acc = acc * base % M;
        base = base * base % M;
    }
    return acc;
}
// In-place number-theoretic transform of a[0..n) modulo M.
// n must be a power of two, and rev[] must hold the bit-reversal permutation
// for n. Pass k (1-based) uses the root wn[k].
// f == -1 computes the inverse: a[1..n-1] is reversed (turning each root into
// its inverse), then everything is multiplied by 1/n.
void NTT(ll* a,int n,int f)
{
    // Bit-reversal reorder. The last index is skipped; for a power-of-two n,
    // rev[n-1] == n-1.
    for (int pos = 0; pos < n - 1; ++pos)
        if (pos < rev[pos]) std::swap(a[pos], a[rev[pos]]);
    int level = 0;
    for (int half = 1; half < n; half <<= 1)
    {
        const ll root = wn[++level];  // root for blocks of length 2*half
        for (int base = 0; base < n; base += half << 1)
        {
            ll w = 1;
            for (int k = 0; k < half; ++k)
            {
                const ll u = a[base + k];
                const ll v = w * a[base + k + half] % M;
                a[base + k] = (u + v) % M;
                a[base + k + half] = (u - v + M) % M;
                w = w * root % M;
            }
        }
    }
    if (f == -1)
    {
        for (int k = 1; k < n >> 1; ++k) std::swap(a[k], a[n - k]);
        const ll invN = Quick_Power(n, M - 2);
        for (int k = 0; k < n; ++k) a[k] = a[k] * invN % M;
    }
}
// Computes T = S^{-1} modulo x^x (the first x coefficients of the inverse
// power series of S) by Newton iteration: T <- T * (2 - S*T), doubling the
// precision at each level.
// Coefficients of T from index x upward are left zero on return.
void Solve(int x)
{
    if (x == 1)
    {
        T[0] = Quick_Power(S[0], M - 2);  // inverse of the constant term
        return;
    }
    Solve((x + 1) >> 1);  // T now holds the inverse modulo x^ceil(x/2)
    // The transform must hold the degree <= 2x-2 product T*T*S without wrap-around.
    for (len = 1, bit = 0; len < 2 * x - 1; len <<= 1) ++bit;
    for (int t = 1; t < len; ++t)
        rev[t] = (rev[t >> 1] >> 1) | ((t & 1) << (bit - 1));
    // c = S truncated to x coefficients, zero-padded to len.
    for (int t = 0; t < len; ++t) c[t] = t < x ? S[t] : 0;
    NTT(T, len, 1);
    NTT(c, len, 1);
    for (int t = 0; t < len; ++t)
        T[t] = T[t] * (2 - T[t] * c[t] % M + M) % M;
    NTT(T, len, -1);
    for (int t = x; t < len; ++t) T[t] = 0;  // truncate to x coefficients
}
int main()
 {
 	cin >>n;
 	for (int i=0;i<20;i++) wn[i]=Quick_Power(G,M - 1 >> i);
 	jc[0]=1;C[0]=1;
    for (ll i=1,j=1;i<1 << 18;i++,j=j*i%M)
      jc[i]=Quick_Power(j,M-2),C[i]=Quick_Power(2,i*(i-1)/2%(M-1));
    for (int i=-1;i<n;i++) S[n-i-1]=C[n-i-1]*jc[n-i-1]%M;
    Solve(n+1);
 	len=1;bit=0;
 	while (len<n*2-1) len <<= 1,bit++;
 	for (int i=1;i<len;i++)
 	  rev[i]=((rev[i >> 1] >> 1) | ((i & 1) << (bit - 1)));
    for (int i=0;i<n;i++) S[i]=C[i+1]*jc[i]%M;
    NTT(S,len,1);NTT(T,len,1);
    for (int i=0;i<len;i++) c[i]=S[i]*T[i]%M;
    NTT(c,len,-1);
    cout <<c[n-1]*Quick_Power(jc[n-1],M-2)%M<<endl;
    return 0;
 }

然后再附上这题的分治解法,复杂度是O(n log^2 n)的,然而在tsinsen上面交了一发,后面4个点超时了

但是感觉这些东西都不算难写,只不过感觉难调

#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
#define N 300050      // capacity of every array (above 2^18 = 262144)
#define M 1004535809  // NTT modulus, 479 * 2^21 + 1
#define G 3           // base used to build the roots of unity in wn[]
#define ll long long

ll S[N];    // S[i] = 2^(i(i-1)/2) / i!  (mod M)
ll T[N];    // T[i]: accumulated sum, finalised by Solve as the answer for i+1 vertices
ll a[N];    // NTT operand: left-half values
ll b[N];    // NTT operand: prefix of S
ll c[N];    // not used by this program
ll jc[N];   // jc[i] = 1 / i!  (mod M)
ll cj[N];   // cj[i] = i!      (mod M)
ll C[N];    // C[i] = 2^(i(i-1)/2)  (mod M)
ll wn[20];  // wn[k] = G^((M-1) >> k), root used by the k-th NTT pass

int n, m;     // n: input size (m is not used by this program)
int rev[N];   // bit-reversal permutation for the current transform length
// Binary exponentiation: returns x^y mod M.
ll Quick_Power(ll x,int y)
{
    ll result = 1;
    while (y != 0)
    {
        if ((y & 1) != 0) result = result * x % M;
        x = x * x % M;
        y >>= 1;
    }
    return result;
}
// In-place number-theoretic transform of a[0..n) modulo M.
// n must be a power of two, and rev[] must hold the bit-reversal permutation
// for n. The pass whose half-block size is p uses the root wn[log2(p)+1].
// f == -1 gives the inverse transform: reverse a[1..n-1], then scale by 1/n.
void NTT(ll* a,int n,int f)
{
    // Indices 0 and n-1 are fixed points of the bit reversal, so they are skipped.
    for (int i = 1; i + 1 < n; ++i)
    {
        const int j = rev[i];
        if (i < j) std::swap(a[i], a[j]);
    }
    int now = 1;
    int p = 1;
    while (p < n)
    {
        const int step = p << 1;
        for (int start = 0; start < n; start += step)
        {
            ll e = 1;
            ll *lo = a + start;
            ll *hi = lo + p;
            for (int j = 0; j < p; ++j)
            {
                const ll x = lo[j];
                const ll y = hi[j] * e % M;
                lo[j] = (x + y) % M;
                hi[j] = (x - y + M) % M;
                e = e * wn[now] % M;
            }
        }
        p = step;
        ++now;
    }
    if (f != -1) return;
    for (int i = 1; i < n >> 1; ++i) std::swap(a[i], a[n - i]);
    const ll inv = Quick_Power(n, M - 2);
    for (int i = 0; i < n; ++i) a[i] = a[i] * inv % M;
}
// Divide-and-conquer (CDQ) evaluation of T[x..y].
// On entry, T[k] holds the contributions already pushed into it from indices
// below x. The leaf step finalises it as
//   T[k] = C[k+1] - k! * T[k].
// After solving the left half [x, z], its values (scaled by 1/i!) are
// convolved with S and the results are added into T[z+1..y] before the
// right half is solved.
void Solve(int x,int y)
{
    if (x == y)
    {
        T[x] = (C[x + 1] + M - T[x] * cj[x] % M) % M;
        return;
    }
    const int z = (x + y) >> 1;
    Solve(x, z);
    const int width = y - x;
    // Smallest power of two >= 2*width; the product then cannot wrap around.
    int len = 1, bit = 0;
    while (len < (width << 1))
    {
        len <<= 1;
        ++bit;
    }
    for (int t = 1; t < len; ++t)
        rev[t] = (rev[t >> 1] >> 1) | ((t & 1) << (bit - 1));
    for (int t = 0; t < len; ++t)
    {
        a[t] = (t <= z - x) ? T[x + t] * jc[x + t] % M : 0;
        b[t] = (t <= width) ? S[t] : 0;
    }
    NTT(a, len, 1);
    NTT(b, len, 1);
    for (int t = 0; t < len; ++t) a[t] = a[t] * b[t] % M;
    NTT(a, len, -1);
    for (int t = z - x + 1; t <= width; ++t) T[x + t] = (T[x + t] + a[t]) % M;
    Solve(z + 1, y);
}
int main()
 {
 	cin >>n;
 	for (int i=0;i<20;i++) wn[i]=Quick_Power(G,M - 1 >> i);
 	jc[0]=C[0]=1;cj[0]=1;
 	for (ll i=1;i<1 << 18;i++)
 	  jc[i]=Quick_Power(cj[i]=i*cj[i-1]%M,M-2),
 	  C[i]=Quick_Power(2,i*(i-1)/2%(M-1));
 	for (ll i=0;i<n;i++) S[i]=C[i]*jc[i]%M;
    Solve(0,n-1);
    cout <<T[n-1]%M<<endl;
    return 0;
 }

登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter