FFT Code
Recursive version
use num::complex::Complex;
use std::f64::consts::PI;
/// Shared recursive radix-2 kernel used by both `fft` and `ifft`.
///
/// The input is split into its even- and odd-indexed samples, each half is
/// transformed recursively, and the halves are merged with butterfly
/// operations.
///
/// Sign convention: with `is_ifft == false` the principal root of unity is
/// `e^{+2*pi*i/n}`; with `is_ifft == true` its conjugate `e^{-2*pi*i/n}` is
/// used. No normalisation is applied here; dividing by `n` is left to the
/// caller.
///
/// # Panics
///
/// Panics if `coeffs.len()` is greater than 1 and not a power of two, since
/// the even/odd split only recombines correctly for power-of-two lengths.
fn fft_base(coeffs: Vec<Complex<f64>>, is_ifft: bool) -> Vec<Complex<f64>> {
    let n = coeffs.len();
    // Lengths 0 and 1 are already their own transform. Treating 0 as a base
    // case matters: an empty input would otherwise split into two empty
    // halves and recurse without end.
    if n <= 1 {
        return coeffs;
    }
    // A non-power-of-two length would leave part of the output unwritten
    // (zero-filled) and produce a silently wrong result, so reject it loudly.
    assert!(
        n.is_power_of_two(),
        "FFT input length must be a power of two, got {}",
        n
    );
    let half = n / 2;

    // Split the input into even- and odd-indexed samples.
    let even_coeffs: Vec<Complex<f64>> = coeffs.iter().step_by(2).cloned().collect();
    let odd_coeffs: Vec<Complex<f64>> = coeffs.iter().skip(1).step_by(2).cloned().collect();

    // Principal n-th root of unity; the inverse transform uses its conjugate.
    let angle = 2.0 * PI / n as f64;
    let theta = if is_ifft { -angle } else { angle };
    let omega = Complex::from_polar(1.0, theta);

    // Transform each half independently.
    let y_even = fft_base(even_coeffs, is_ifft);
    let y_odd = fft_base(odd_coeffs, is_ifft);

    // Butterfly merge: output k and output k + n/2 share the same product
    // `omega^k * y_odd[k]`, added for the lower half and subtracted for the
    // upper half.
    let mut y = vec![Complex::new(0.0, 0.0); n];
    let mut current_omega = Complex::new(1.0, 0.0);
    for (k, (&even, &odd)) in y_even.iter().zip(y_odd.iter()).enumerate() {
        let t = current_omega * odd;
        y[k] = even + t;
        y[k + half] = even - t;
        current_omega *= omega;
    }
    y
}
/// Forward Fast Fourier Transform of `coeffs`.
///
/// Thin wrapper that runs the shared kernel `fft_base` in forward mode
/// (`is_ifft == false`). The result is returned exactly as the kernel
/// produces it; no scaling is applied here.
fn fft(coeffs: Vec<Complex<f64>>) -> Vec<Complex<f64>> {
    let inverse = false;
    fft_base(coeffs, inverse)
}
/// Inverse Fast Fourier Transform of `coeffs`.
///
/// Runs the shared kernel `fft_base` in inverse mode (`is_ifft == true`) and
/// then divides every output sample by the input length, which is the
/// normalisation the kernel itself does not perform.
fn ifft(coeffs: Vec<Complex<f64>>) -> Vec<Complex<f64>> {
    // Capture the length before `coeffs` is moved into the kernel.
    let len = coeffs.len() as f64;
    let unscaled = fft_base(coeffs, true);
    unscaled.into_iter().map(|value| value / len).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip check: applying `fft` and then `ifft` to a 16-sample
    /// real-valued ramp must reproduce the ramp to within 1e-10.
    #[test]
    fn test_fft_ifft() {
        // Samples 1.0, 2.0, ..., 16.0 with zero imaginary part.
        let input: Vec<Complex<f64>> = (1..=16_u32)
            .map(|sample| Complex::new(f64::from(sample), 0.0))
            .collect();

        // Forward transform (on a copy, so the original stays available).
        let fft_result = fft(input.clone());
        println!("fft_result: {:?}", fft_result);

        // Inverse transform of the spectrum.
        let ifft_result = ifft(fft_result);
        println!("ifft_result: {:?}", ifft_result);

        // Every recovered sample must match its original within tolerance.
        for (expected, actual) in input.iter().zip(&ifft_result) {
            let error = (expected - actual).norm();
            assert!(error < 1e-10, "IFFT result does not match original input");
        }
    }
}
Iterative version
use num::complex::Complex;
use std::f64::consts::PI;
/// Iterative, in-place radix-2 Cooley-Tukey Fast Fourier Transform.
///
/// The samples are first reordered by bit-reversed index, then combined in
/// butterfly passes whose span (`step`) doubles each pass. The twiddle
/// factor for a pass is `e^{-i*pi/step}`, i.e. the negative-exponent forward
/// convention. No scaling is applied to the output.
///
/// Inputs of length 0 or 1 are returned unchanged.
///
/// # Panics
///
/// Panics if `input.len()` is greater than 1 and not a power of two. Without
/// this check such lengths fail later with an opaque index-out-of-bounds
/// panic inside the butterfly loop.
fn fft(mut input: Vec<Complex<f64>>) -> Vec<Complex<f64>> {
    let n = input.len();
    if n <= 1 {
        return input;
    }
    assert!(
        n.is_power_of_two(),
        "FFT input length must be a power of two, got {}",
        n
    );

    // Bit-reversal permutation. Swapping only when `i < j` ensures each pair
    // is exchanged exactly once.
    for i in 0..n {
        let j = reverse_bits(i, n);
        if i < j {
            input.swap(i, j);
        }
    }

    // Butterfly passes: `step` is the distance between the two elements of a
    // butterfly, `jump` the distance between consecutive butterflies that
    // share the same twiddle factor `w`.
    let mut step = 1;
    while step < n {
        let jump = step * 2;
        let omega = Complex::from_polar(1.0, -PI / step as f64);
        let mut w = Complex::new(1.0, 0.0);
        for group in 0..step {
            for pair in (group..n).step_by(jump) {
                // Read the upper element once; it feeds both outputs.
                let upper = input[pair];
                let t = input[pair + step] * w;
                input[pair] = upper + t;
                input[pair + step] = upper - t;
            }
            w *= omega;
        }
        step *= 2;
    }
    input
}
/// Iterative, in-place radix-2 Cooley-Tukey inverse Fast Fourier Transform.
///
/// The samples are first reordered by bit-reversed index, then combined in
/// butterfly passes whose span (`step`) doubles each pass. The twiddle
/// factor for a pass is `e^{+i*pi/step}` (positive exponent), and every
/// output sample is finally multiplied by `1 / n`.
///
/// Inputs of length 0 or 1 are returned unchanged.
///
/// # Panics
///
/// Panics if `input.len()` is greater than 1 and not a power of two. Without
/// this check such lengths fail later with an opaque index-out-of-bounds
/// panic inside the butterfly loop.
fn ifft(mut input: Vec<Complex<f64>>) -> Vec<Complex<f64>> {
    let n = input.len();
    if n <= 1 {
        return input;
    }
    assert!(
        n.is_power_of_two(),
        "IFFT input length must be a power of two, got {}",
        n
    );

    // Bit-reversal permutation. Swapping only when `i < j` ensures each pair
    // is exchanged exactly once.
    for i in 0..n {
        let j = reverse_bits(i, n);
        if i < j {
            input.swap(i, j);
        }
    }

    // Butterfly passes: `step` is the distance between the two elements of a
    // butterfly, `jump` the distance between consecutive butterflies that
    // share the same twiddle factor `w`.
    let mut step = 1;
    while step < n {
        let jump = step * 2;
        // Positive angle: the only difference from the forward butterflies.
        let omega = Complex::from_polar(1.0, PI / step as f64);
        let mut w = Complex::new(1.0, 0.0);
        for group in 0..step {
            for pair in (group..n).step_by(jump) {
                // Read the upper element once; it feeds both outputs.
                let upper = input[pair];
                let t = input[pair + step] * w;
                input[pair] = upper + t;
                input[pair + step] = upper - t;
            }
            w *= omega;
        }
        step *= 2;
    }

    // Normalisation by the transform length.
    let scale = 1.0 / n as f64;
    input.iter_mut().for_each(|x| *x *= scale);
    input
}
/// Mirrors the low-order bits of `num`.
///
/// The number of bits mirrored is `bit_count.trailing_zeros()`, so despite
/// its name `bit_count` is not a bit width: for a power-of-two value `2^k`
/// exactly `k` bits are reversed. Bits of `num` above that width are
/// discarded.
fn reverse_bits(mut num: usize, bit_count: usize) -> usize {
    let width = bit_count.trailing_zeros();
    // Peel bits off the bottom of `num` and push them onto the bottom of the
    // accumulator, so the first bit taken ends up in the highest position.
    (0..width).fold(0, |mirrored, _| {
        let lowest = num & 1;
        num >>= 1;
        (mirrored << 1) | lowest
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip check: a 16-sample real-valued ramp pushed through `fft`
    /// and then `ifft` must come back equal to itself within 1e-10.
    #[test]
    fn test_fft_ifft() {
        // Build the ramp 1.0, 2.0, ..., 16.0 as purely real complex samples.
        let mut input: Vec<Complex<f64>> = Vec::with_capacity(16);
        for sample in 1..=16_u32 {
            input.push(Complex::new(f64::from(sample), 0.0));
        }

        // Forward transform of a copy of the ramp.
        let fft_result = fft(input.clone());
        println!("fft_result: {:?}", fft_result);

        // Inverse transform of the spectrum.
        let ifft_result = ifft(fft_result);
        println!("ifft_result: {:?}", ifft_result);

        // Compare the recovered samples with the originals pairwise.
        let pairs = input.iter().zip(ifft_result.iter());
        for (expected, recovered) in pairs {
            let deviation = (expected - recovered).norm();
            assert!(deviation < 1e-10, "IFFT result does not match original input");
        }
    }
}