数学板块学习之NTT

    xiaoxiao2022-07-07  206

    参考博客: https://blog.csdn.net/caoguo_app_android/article/details/44067483

    任意模数NTT(三模数NTT)

    NTT

    快速数论变换(Number-Theoretic Transform,NTT),实际上就是模意义下的FFT。 所以在看NTT前最好看一下上一篇FFT

    因为FFT在运算时是用的复数,涉及到了double,所以不可避免的会出现一定的精度问题。这个问题是最严重的,而其他的问题FFT也会有,比如数过大时,答案可能会超级大,所以取模又成了问题。

    首先FFT能够使用就是依靠 ω n \omega_n ωn的性质,而在数论方面,我们可以找到一个性质与 ω \omega ω十分类似的东西 这些性质包括: ω n 0 = ω n n = 1 \omega_{n}^{0}=\omega_n^n=1 ωn0=ωnn=1 ω n 0 , ω n 1 , ⋯   , ω n n − 1 \omega_n^0,\omega_n^1,\cdots,\omega_n^{n-1} ωn0,ωn1,,ωnn1互不相同(可以想象或者看FFT里面对单位根的介绍中说得欧拉恒等式,结合单位圆的图形可以更好的接受这个性质) ω k n k i = ω n i \omega_{kn}^{ki}=\omega_n^i ωknki=ωni 在数论里,我们可以找到原根

    原根

    对于素数 p p p的原根 g g g定义为 ∀ 0 ≤ i < j < p − 1 \forall 0\leq i<j< p-1 0i<j<p1满足 g i = ̸ g j   ( m o d   p ) g^i=\not g^j (mod \ p) gi≠gj (mod p) 素数一定存在原根

    我们类比FFT,用原根 g n 1 = g p − 1 n g_n^1=g^{\frac{p-1}{n}} gn1=gnp1代替 ω n 1 \omega_n^{1} ωn1 FFT: y k = ∑ i = 0 n − 1 c i ( ω n k ) i y_k=\sum_{i=0}^{n-1}c_i(\omega_n^k)^i yk=i=0n1ci(ωnk)i NTT: y k = ∑ i = 0 n − 1 c i ( g n k ) i ( m o d   P ) y_k=\sum_{i=0}^{n-1}c_i(g_n^k)^i(mod\ P) yk=i=0n1ci(gnk)i(mod P) IFFT: c k = 1 n ∑ i = 0 n − 1 y i ( ω n − k ) i c_k=\cfrac{1}{n}\sum_{i=0}^{n-1}y_i(\omega_n^{-k})^i ck=n1i=0n1yi(ωnk)i INTT c k = 1 n ∑ i = 0 n − 1 y i ( g n − k ) i ( m o d   P ) c_k=\frac{1}{n}\sum_{i=0}^{n-1}y_i(g_n^{-k})^i(mod\ P) ck=n1i=0n1yi(gnk)i(mod P)

    通过这种对应,我们将复数转换为了整数求解。 对于NTT我们经常选择的模为 469762049 , 998244353 , 1004535809 469762049,998244353,1004535809 469762049,998244353,1004535809。其原根都为3

    NTT求大数乘法模板

    #include <iostream> #include <algorithm> #include <cstdio> #include <queue> #include <cmath> #include <string> #include <cstring> #include <map> #include <set> #include <cmath> #include <tr1/unordered_map> using namespace std; #define me(x,y) memset(x,y,sizeof x) #define MIN(x,y) x < y ? x : y #define MAX(x,y) x > y ? x : y typedef long long ll; typedef unsigned long long ull; const double eps = 1e-08; const double PI = acos(-1.0); const int N = 1<<18; const int P = (479<<21)+1; const int G = 3; const int NUM = 20; ll qpow(ll a,ll b,ll p){ ll ans = 1; a %= p; while(b){ if(b&1) ans = (ans*a)%p; b >>= 1; a = (a*a)%p; } return ans; } ll wn[NUM],a[N],b[N]; char A[N],B[N]; void getwn(){ for(int i = 0; i < NUM; ++i){ int t = 1<<i; wn[i] = qpow(G,(P-1)/t,P); } } void Rader(ll a[],int len){ int j = len>>1; for(int i = 1;i < len-1; ++i){ if(i < j) swap(a[i],a[j]); int k = len>>1; while(j >= k){ j -= k; k >>= 1; } if(j < k) j += k; } } void ntt(ll a[],int len,int on){ Rader(a,len); int id = 0; for(int h = 2; h <= len;h <<= 1){ id++; for(int j = 0 ; j < len; j += h){ ll w = 1; for(int k = j; k < j+h/2; ++k){ ll u = a[k]%P; ll t = w*a[k+h/2]%P; a[k] = (u+t)%P; a[k+h/2]=(u-t+P)%P; w = w*wn[id]%P; } } } if(on == -1){ for(int i = 1; i < len/2; ++i){ swap(a[i],a[len-i]); } ll inv = qpow(len,P-2,P); for(int i = 0; i < len; ++i){ a[i] = a[i]*inv%P; } } } int main(){ getwn(); int n; cin>>n; scanf("%s%s",A,B); int len = 1; int len1 = strlen(A); int len2 = strlen(B); while(len <= len1*2 || len <= len2*2) len <<= 1; for(int i = 0; i < len1; ++i) a[i] = A[len1-i-1]-'0'; for(int i = len1; i < len; ++i) a[i] = 0; for(int i = 0; i < len2; ++i) b[i] = B[len2-i-1]-'0'; for(int i = len2; i < len; ++i) b[i] = 0; ntt(a,len,1); ntt(b,len,1); for(int i = 0; i < len; ++i) a[i] = a[i]*b[i]%P; ntt(a,len,-1); for(int i = 0; i < len; ++i){ a[i+1] += a[i]/10; a[i] %= 10; } len = len1+len2-1; while(len > 0 && a[len] <= 0) len--; for(int i = len; i >= 0; --i){ printf("%c",(int)a[i]+'0'); } cout<<endl; return 0; } /* */
    最新回复(0)