ICPC Competitive Programming

FFT/NTT

Fast Fourier Transform และ Number Theoretic Transform สำหรับ Polynomial Multiplication

FFT/NTT

Fast Fourier Transform (FFT) และ Number Theoretic Transform (NTT) ใช้สำหรับ polynomial multiplication ใน $O(n \log n)$

Polynomial Multiplication

การคูณ polynomials $A(x)$ และ $B(x)$ ที่มี degree $n$:

  • Naive: $O(n^2)$
  • FFT/NTT: $O(n \log n)$

Fast Fourier Transform (FFT)

ใช้ complex numbers และ roots of unity

Cooley-Tukey FFT

// Complex sample type shared by the FFT routines.
using cd = std::complex<double>;
// The constant pi, computed once as acos(-1).
const double PI = std::acos(-1);

// In-place iterative radix-2 FFT of `a`.
//
// a      - samples to transform; overwritten with the result. The length is
//          expected to be a power of two (the callers in this file pad to
//          one); the butterfly indexing below is not valid for other lengths.
// invert - false: forward transform, using the angle +2*pi/len per stage.
//          true:  inverse transform, using the negated angle and then
//                 dividing every element by the length.
void fft(std::vector<cd>& a, bool invert) {
    const int n = a.size();
    if (n == 1) return;  // a single sample is its own transform

    // Stage 1: move every element to its bit-reversed index. `rev` is kept
    // equal to the bit reversal of `idx` by incrementing it from the top bit.
    int rev = 0;
    for (int idx = 1; idx < n; ++idx) {
        int mask = n >> 1;
        while (rev & mask) {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if (idx < rev) std::swap(a[idx], a[rev]);
    }

    // Stage 2: combine sub-transforms of size span/2 into size span.
    const int sign = invert ? -1 : 1;
    for (int span = 2; span <= n; span <<= 1) {
        const int half = span / 2;
        const double theta = 2 * PI / span * sign;
        const cd step(std::cos(theta), std::sin(theta));  // root of unity for this stage

        for (int base = 0; base < n; base += span) {
            cd twiddle(1);
            for (int k = 0; k < half; ++k) {
                cd& lower = a[base + k];
                cd& upper = a[base + k + half];
                const cd even = lower;
                const cd odd = upper * twiddle;
                lower = even + odd;
                upper = even - odd;
                twiddle *= step;
            }
        }
    }

    // The inverse transform carries a 1/n normalisation factor.
    if (invert) {
        for (auto& value : a) value /= n;
    }
}

vector<long long> multiply(vector<long long>& a, vector<long long>& b) {
    vector<cd> fa(a.begin(), a.end()), fb(b.begin(), b.end());
    
    int n = 1;
    while (n < a.size() + b.size()) n <<= 1;
    fa.resize(n);
    fb.resize(n);
    
    fft(fa, false);
    fft(fb, false);
    
    for (int i = 0; i < n; i++) {
        fa[i] *= fb[i];
    }
    
    fft(fa, true);
    
    vector<long long> result(n);
    for (int i = 0; i < n; i++) {
        result[i] = round(fa[i].real());
    }
    
    // Remove leading zeros
    while (result.size() > 1 && result.back() == 0) {
        result.pop_back();
    }
    
    return result;
}

Number Theoretic Transform (NTT)

ใช้ modular arithmetic แทน complex numbers - ไม่มี floating point errors

Requirements

MOD ต้องเป็น prime ในรูป $MOD = c \cdot 2^k + 1$ โดยที่ $2^k \geq n$ (โดย $n$ คือความยาวของ transform หลังจาก pad เป็น power of 2 แล้ว)

Common NTT-friendly primes:

  • $998244353 = 119 \cdot 2^{23} + 1$ (primitive root = 3)
  • $754974721 = 45 \cdot 2^{24} + 1$ (primitive root = 11)
  • $167772161 = 5 \cdot 2^{25} + 1$ (primitive root = 3)

Implementation

// NTT modulus: the prime 998244353 = 119 * 2^23 + 1.
const long long MOD = 998244353;
// Primitive root modulo MOD.
const long long g = 3;

// Modular exponentiation by repeated squaring: returns a^b modulo `mod`.
//
// a   - base; may be negative or larger than `mod`, it is reduced first.
// b   - exponent; values <= 0 leave the loop unexecuted, giving 1 % mod.
// mod - modulus; must be positive and small enough that (mod - 1)^2 fits in
//       a long long, because intermediate products are not widened.
//
// The result is always in [0, mod).
long long power(long long a, long long b, long long mod) {
    long long res = 1 % mod;  // 0 when mod == 1, otherwise 1
    a %= mod;
    if (a < 0) a += mod;      // C++ % keeps the dividend's sign; normalise it
    while (b > 0) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

// In-place iterative number theoretic transform of `a` modulo MOD.
//
// a      - values to transform; overwritten with the result. The length is
//          expected to be a power of two, and (MOD - 1) / len is used as an
//          exponent at every stage, so it should divide MOD - 1 exactly.
// invert - false: forward transform, using g^((MOD-1)/len) as the stage root.
//          true:  inverse transform, using g^(MOD-1-(MOD-1)/len) as the stage
//                 root and finally multiplying every element by the modular
//                 inverse of the length.
void ntt(std::vector<long long>& a, bool invert) {
    const int n = a.size();

    // Stage 1: bit-reversal permutation. `rev` tracks the reversed bits of
    // `idx` and is advanced by a top-down increment.
    int rev = 0;
    for (int idx = 1; idx < n; ++idx) {
        int mask = n >> 1;
        while (rev & mask) {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if (idx < rev) std::swap(a[idx], a[rev]);
    }

    // Stage 2: butterflies, doubling the sub-transform size each pass.
    for (int span = 2; span <= n; span <<= 1) {
        const int half = span / 2;
        const long long stride = (MOD - 1) / span;
        const long long root = invert ? power(g, MOD - 1 - stride, MOD)
                                      : power(g, stride, MOD);

        for (int base = 0; base < n; base += span) {
            long long twiddle = 1;
            for (int k = 0; k < half; ++k) {
                const long long even = a[base + k];
                const long long odd = a[base + k + half] * twiddle % MOD;
                a[base + k] = (even + odd) % MOD;
                // + MOD keeps the difference non-negative before reducing.
                a[base + k + half] = (even - odd + MOD) % MOD;
                twiddle = twiddle * root % MOD;
            }
        }
    }

    // Inverse transform: scale by n^(MOD-2), the inverse of n modulo MOD.
    if (invert) {
        const long long inv_n = power(n, MOD - 2, MOD);
        for (auto& value : a) {
            value = value * inv_n % MOD;
        }
    }
}

// Multiplies two polynomials with coefficients taken modulo MOD, using the NTT.
//
// a, b    - coefficient vectors, lowest degree first (taken by value because
//           they are padded and transformed in place). Coefficients may be
//           negative or >= MOD; they are reduced into [0, MOD) first.
// returns - coefficients of a * b modulo MOD, lowest degree first, with
//           trailing zero coefficients removed but always at least one
//           element (empty input yields {0}).
//
// The transform length is the smallest power of two >= a.size() + b.size();
// ntt() uses (MOD - 1) / len as an exponent, so that length should divide
// MOD - 1.
std::vector<long long> multiplyNTT(std::vector<long long> a, std::vector<long long> b) {
    // ntt() and the pointwise product assume residues in [0, MOD); without
    // this step negative inputs can leak through as negative results and
    // large inputs can overflow the 64-bit products.
    auto reduce = [](std::vector<long long>& coeffs) {
        for (long long& c : coeffs) {
            c %= MOD;
            if (c < 0) c += MOD;
        }
    };
    reduce(a);
    reduce(b);

    // Sized in std::size_t so the comparison does not mix signed and unsigned.
    const std::size_t needed = a.size() + b.size();
    std::size_t n = 1;
    while (n < needed) n <<= 1;
    a.resize(n);
    b.resize(n);

    ntt(a, false);
    ntt(b, false);

    for (std::size_t i = 0; i < n; ++i) {
        a[i] = a[i] * b[i] % MOD;
    }

    ntt(a, true);

    // Drop the zero padding above the highest non-zero coefficient.
    while (a.size() > 1 && a.back() == 0) {
        a.pop_back();
    }

    return a;
}

Applications

1. Large Number Multiplication

// Multiplies two decimal numbers given as digit strings and returns the
// product as a digit string without leading zeros (a zero product is "0").
//
// Each character is converted with `ch - '0'` without validation, so the
// inputs are expected to contain decimal digits only.
std::string multiplyStrings(std::string num1, std::string num2) {
    // Digit vectors, least-significant digit first, so index == power of ten.
    std::vector<long long> a, b;
    for (auto it = num1.rbegin(); it != num1.rend(); ++it) {
        a.push_back(*it - '0');
    }
    for (auto it = num2.rbegin(); it != num2.rend(); ++it) {
        b.push_back(*it - '0');
    }

    // Convolution of the digit sequences; entries may exceed 9 at this point.
    std::vector<long long> c = multiply(a, b);

    // Propagate carries so every entry becomes a single decimal digit.
    long long carry = 0;
    for (long long& digit : c) {
        digit += carry;
        carry = digit / 10;
        digit %= 10;
    }
    for (; carry != 0; carry /= 10) {
        c.push_back(carry % 10);
    }

    // Emit most-significant digit first.
    std::string result;
    for (auto it = c.rbegin(); it != c.rend(); ++it) {
        result += static_cast<char>('0' + *it);
    }

    // Skip leading zeros but always keep the final character.
    std::size_t first = 0;
    while (first + 1 < result.size() && result[first] == '0') ++first;
    return result.substr(first);
}

2. Counting Sum of Pairs

นับจำนวนวิธีที่ได้ผลรวม $k$ จากการเลือก 1 ตัวจากแต่ละ array

// Counts, for every sum k, the number of pairs (x from a, y from b) with
// x + y == k, by multiplying the two value histograms.
//
// a, b    - input values; only values in [0, maxVal] are counted. Values
//           outside that range are ignored, since they have no histogram
//           bucket and indexing with them would be out of bounds.
// maxVal  - largest value a histogram bucket is allocated for.
// returns - vector of length 2 * maxVal + 1 where result[k] is the number of
//           pairs summing to k, reduced modulo MOD (multiplyNTT works modulo
//           MOD, so counts are exact only while they stay below MOD).
//           Empty when maxVal is negative.
std::vector<long long> countSums(const std::vector<int>& a, const std::vector<int>& b, int maxVal) {
    if (maxVal < 0) return {};  // no representable values, hence no sums

    const std::size_t buckets = static_cast<std::size_t>(maxVal) + 1;
    std::vector<long long> freqA(buckets), freqB(buckets);

    auto tally = [maxVal](const std::vector<int>& values, std::vector<long long>& freq) {
        for (int x : values) {
            if (x >= 0 && x <= maxVal) ++freq[x];
        }
    };
    tally(a, freqA);
    tally(b, freqB);

    std::vector<long long> result = multiplyNTT(freqA, freqB);
    // multiplyNTT strips trailing zero coefficients; pad back so that every
    // k in [0, 2 * maxVal] is a valid index even when the largest sums are
    // unreachable.
    result.resize(2 * buckets - 1);
    return result;
}

3. String Matching with Wildcards

// Wildcard pattern matching, where '?' matches any character.
//
// NOTE(review): placeholder only. The body is elided ("...") in this
// document, so as written the function has no return statement even though it
// is declared to return vector<int>; it must be completed before it is called.
// The meaning of the returned vector is not shown here (presumably the
// positions in `text` where `pattern` matches -- confirm when implementing).
vector<int> wildcardMatch(string& text, string& pattern) {
    // Intended approach: convert the matching problem to polynomial multiplication.
    // ... (implementation omitted)
}

4. Polynomial Division

// Polynomial division of `a` by `b`.
//
// NOTE(review): placeholder only. The body is elided ("...") in this
// document, so as written the function has no return statement even though it
// is declared to return vector<long long>; it must be completed before it is
// called. Whether the result is meant to be the quotient, the remainder, or
// both is not shown here -- confirm when implementing.
vector<long long> polyDiv(vector<long long> a, vector<long long> b) {
    // Intended approach: use Newton's method with NTT.
    // ... (implementation omitted)
}

FFT vs NTT Comparison

| Aspect | FFT | NTT |
|---|---|---|
| Precision | Floating point errors | Exact |
| Modular | No | Yes |
| Speed | Faster constants | Slightly slower |
| Use case | General multiplication | Modular problems |

Complexity

| Operation | Time | Space |
|---|---|---|
| FFT/NTT | $O(n \log n)$ | $O(n)$ |
| Polynomial multiplication | $O(n \log n)$ | $O(n)$ |

Tips

  1. Power of 2 - pad arrays to next power of 2
  2. Precision - FFT อาจมี error, ใช้ round() หรือเปลี่ยนเป็น NTT
  3. Multiple MODs - สำหรับ results ที่ใหญ่มาก ใช้ CRT กับหลาย NTT primes
  4. Precompute roots - เร่งความเร็วโดย precompute roots of unity (twiddle factors) ของแต่ละ stage ไว้ล่วงหน้า แทนการคำนวณซ้ำทุกครั้งที่เรียก transform