gwsnr.mlx.mlx_functions
MLX-compiled functions for gravitational wave data analysis.
This module provides high-performance MLX implementations of core functions used in gravitational wave signal-to-noise ratio (SNR) calculations and parameter estimation. Key features include:
- Chirp time calculations using the 3.5 post-Newtonian (3.5PN) approximation
- Antenna response pattern computations for gravitational wave detectors
- Polarization tensor calculations for the plus and cross modes
- Coordinate transformations between the celestial frame and the Earth-fixed (detector) frame
- Vectorized operations for efficient batch processing
- Automatic vectorization over multi-dimensional arrays through MLX's vmap
All functions are compiled with MLX's @mx.compile decorator, which reduces per-call overhead and can fuse operations into fewer kernels. Arrays live in Apple silicon's unified memory, so no copies between CPU and GPU are needed. The implementations are intended for use in Bayesian inference pipelines and matched-filtering applications in gravitational wave astronomy.
Module Contents
Functions
- findchirp_chirptime_mlx(m1, m2, fmin): Function to calculate the chirp time from minimum frequency to last stable orbit (MLX implementation).
- gps_to_gmst(gps_time): Function to convert GPS time to Greenwich Mean Sidereal Time (GMST) (MLX implementation).
- ra_dec_to_theta_phi(ra, dec, gmst): Function to convert right ascension and declination to spherical coordinates (MLX implementation).
- get_polarization_tensor_plus(ra, dec, time, psi): Function to calculate the plus polarization tensor for gravitational wave detection (MLX implementation).
- get_polarization_tensor_cross(ra, dec, time, psi): Function to calculate the cross polarization tensor for gravitational wave detection (MLX implementation).
- antenna_response_plus(ra, dec, time, psi, detector_tensor): Function to calculate the plus polarization antenna response for gravitational wave detection (MLX implementation).
- antenna_response_cross(ra, dec, time, psi, detector_tensor): Function to calculate the cross polarization antenna response for gravitational wave detection (MLX implementation).
- antenna_response_array(ra, dec, time, psi, detector_tensor): Function to calculate the antenna response for multiple detectors and sources (MLX implementation).
- gwsnr.mlx.mlx_functions.findchirp_chirptime_mlx(m1, m2, fmin)
Function to calculate the chirp time from minimum frequency to last stable orbit (MLX implementation).
Computes the time taken to evolve from f_min to f_lso (the last-stable-orbit frequency). Post-Newtonian terms up to 3.5PN order in the Fourier phase are included.
- Parameters:
- m1 : float or mx.array
Mass of the first body in solar masses.
- m2 : float or mx.array
Mass of the second body in solar masses.
- fmin : float or mx.array
Lower frequency cutoff in Hz.
- Returns:
- chirp_time : mx.array
Time taken from f_min to f_lso (last stable orbit frequency) in seconds.
Notes
Calculates the chirp time using the 3.5PN approximation to the gravitational wave Fourier phase. The result is the time the signal takes to sweep from fmin to the last-stable-orbit frequency. The post-Newtonian expansion coefficients are evaluated with MLX array operations, so the function can be JIT-compiled with mx.compile.
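The following is a minimal pure-Python sketch of the chirp-time series in its TaylorT2 form. It is not the module's MLX code. It keeps only the Newtonian, 1PN, 1.5PN and 2PN terms, so it omits the 2.5PN to 3.5PN terms that findchirp_chirptime_mlx includes. The function name chirp_time_2pn and the rounded solar-mass-in-seconds constant MTSUN_S are assumptions made for illustration.

```python
import math

# Assumed, rounded value of G * M_sun / c**3 in seconds.
MTSUN_S = 4.925491e-6


def chirp_time_2pn(m1, m2, fmin):
    """Hypothetical sketch: TaylorT2 chirp time from fmin, truncated at 2PN.

    m1, m2 are in solar masses and fmin is in Hz; the result is in seconds.
    The module itself goes to 3.5PN; the higher-order terms are omitted here.
    """
    m = m1 + m2                                # total mass (M_sun)
    eta = m1 * m2 / m**2                       # symmetric mass ratio
    mt = m * MTSUN_S                           # total mass in seconds
    v = (math.pi * mt * fmin) ** (1.0 / 3.0)   # PN expansion parameter at fmin

    # Standard TaylorT2 time coefficients at 1PN, 1.5PN and 2PN
    c2 = 743.0 / 252.0 + 11.0 * eta / 3.0
    c3 = -32.0 * math.pi / 5.0
    c4 = 3058673.0 / 508032.0 + eta * (5429.0 / 504.0 + 617.0 * eta / 72.0)

    series = 1.0 + c2 * v**2 + c3 * v**3 + c4 * v**4
    # Newtonian prefactor 5 M / (256 eta v^8) times the PN series
    return 5.0 * mt / (256.0 * eta) * series / v**8
```

For two 1.4 solar-mass bodies with fmin = 20 Hz, the sketch gives roughly 160 s. Heavier systems sweep through the band in a fraction of that time.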
- gwsnr.mlx.mlx_functions.gps_to_gmst(gps_time)
Function to convert GPS time to Greenwich Mean Sidereal Time (GMST) (MLX implementation).
- Parameters:
- gps_time : float
GPS time in seconds.
- Returns:
- gmst : float
Greenwich Mean Sidereal Time in radians.
Notes
Uses a linear approximation in which GMST advances at a constant rate from a reference time. The reference time (time0) is 1126259642.413 seconds and the slope is 7.292115855382993e-05 radians per second, which is approximately Earth's sidereal rotation rate. Because the function is built from MLX operations, it supports automatic differentiation (for example with mx.grad) for gradient-based optimization.
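Below is a plain-Python sketch of the linear GMST approximation described above. The reference time and slope come from the notes. The GMST value at time0 (gmst0) is not documented, so here it is a hypothetical input. The final wrap into [0, 2*pi) is also an assumed convenience, not documented behavior.

```python
import math

TIME0 = 1126259642.413            # reference GPS time from the notes (s)
SLOPE = 7.292115855382993e-05     # slope from the notes (rad/s)


def gps_to_gmst_linear(gps_time, gmst0):
    """Hypothetical sketch of a linear GMST approximation.

    gmst0 (radians) is the GMST at TIME0. It is an assumed input, because
    the offset used by the module is not given in its documentation.
    """
    gmst = gmst0 + SLOPE * (gps_time - TIME0)
    return gmst % (2.0 * math.pi)  # assumed wrap into [0, 2*pi)
```

One sidereal day is 2*pi / SLOPE, about 86164.09 s, after which the sketch returns the same angle.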
- gwsnr.mlx.mlx_functions.ra_dec_to_theta_phi(ra, dec, gmst)
Function to convert right ascension and declination to spherical coordinates (MLX implementation).
- Parameters:
- ra : float
Right ascension of the source in radians.
- dec : float
Declination of the source in radians.
- gmst : float
Greenwich Mean Sidereal Time in radians.
- Returns:
- theta : float
Polar angle (colatitude) in radians, measured from the north pole.
- phi : float
Azimuthal angle in radians, adjusted for Earth's rotation.
Notes
Converts celestial coordinates (ra, dec) to spherical coordinates (theta, phi) in the Earth-fixed (detector) frame. Theta is the colatitude (pi/2 - dec), measured from the north pole. Phi is the azimuthal angle, corrected for Earth's rotation using GMST. Because the function is built from MLX operations, it supports automatic differentiation for parameter estimation and optimization workflows.
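The conversion can be sketched as follows. The colatitude follows directly from dec. The form phi = ra - gmst is an assumption: it is the convention used by bilby's ra_dec_to_theta_phi. The helper name is hypothetical.

```python
import math


def ra_dec_to_theta_phi_sketch(ra, dec, gmst):
    """Hypothetical sketch: celestial (ra, dec) -> Earth-fixed (theta, phi)."""
    theta = math.pi / 2.0 - dec   # colatitude, measured from the north pole
    phi = ra - gmst               # assumed bilby-style rotation correction
    return theta, phi
```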
- gwsnr.mlx.mlx_functions.get_polarization_tensor_plus(ra, dec, time, psi)
Function to calculate the plus polarization tensor for gravitational wave detection (MLX implementation).
- Parameters:
- ra : float
Right ascension of the source in radians.
- dec : float
Declination of the source in radians.
- time : float
GPS time of the source in seconds.
- psi : float
Polarization angle of the source in radians.
- Returns:
- polarization_tensor_plus : mx.array
3x3 plus polarization tensor matrix (m⊗m - n⊗n).
Notes
Calculates the plus polarization tensor in the detector frame. The celestial coordinates are first converted to spherical coordinates using the GMST at the given time. The basis vectors m and n are then computed from the polarization angle psi. Returns the tensor m⊗m - n⊗n for the plus polarization mode. The MLX implementation supports automatic differentiation and runs on the Apple silicon GPU.
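Here is a plain-Python sketch of how m, n and the plus tensor can be built. It follows the basis convention of bilby's get_polarization_tensor; whether this module uses exactly the same signs is an assumption. For brevity the sketch takes GMST directly instead of a GPS time. All helper names are hypothetical.

```python
import math


def outer(a, b):
    """Outer product a ⊗ b of two 3-vectors, as nested lists."""
    return [[a[i] * b[j] for j in range(3)] for i in range(3)]


def polarization_basis(ra, dec, gmst, psi):
    """Basis vectors m, n (bilby-style sign convention, assumed here)."""
    theta = math.pi / 2.0 - dec          # colatitude
    phi = ra - gmst                      # Earth-fixed longitude (assumed form)
    # u and v span the plane perpendicular to the source direction
    u = [math.cos(phi) * math.cos(theta),
         math.cos(theta) * math.sin(phi),
         -math.sin(theta)]
    v = [-math.sin(phi), math.cos(phi), 0.0]
    # rotate (u, v) by the polarization angle psi
    m = [-u[k] * math.sin(psi) - v[k] * math.cos(psi) for k in range(3)]
    n = [-u[k] * math.cos(psi) + v[k] * math.sin(psi) for k in range(3)]
    return m, n


def polarization_tensor_plus_sketch(ra, dec, gmst, psi):
    """Plus polarization tensor m⊗m - n⊗n as a 3x3 nested list."""
    m, n = polarization_basis(ra, dec, gmst, psi)
    mm, nn = outer(m, m), outer(n, n)
    return [[mm[i][j] - nn[i][j] for j in range(3)] for i in range(3)]
```

Because m and n are orthonormal, the plus tensor is symmetric and traceless, and its squared Frobenius norm is 2. This makes a convenient unit test.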
- gwsnr.mlx.mlx_functions.get_polarization_tensor_cross(ra, dec, time, psi)
Function to calculate the cross polarization tensor for gravitational wave detection (MLX implementation).
- Parameters:
- ra : float
Right ascension of the source in radians.
- dec : float
Declination of the source in radians.
- time : float
GPS time of the source in seconds.
- psi : float
Polarization angle of the source in radians.
- Returns:
- polarization_tensor_cross : mx.array
3x3 cross polarization tensor matrix (m⊗n + n⊗m).
Notes
Calculates the cross polarization tensor in the detector frame. The celestial coordinates are first converted to spherical coordinates using the GMST at the given time. The basis vectors m and n are then computed from the polarization angle psi. Returns the tensor m⊗n + n⊗m for the cross polarization mode. The MLX implementation supports automatic differentiation and runs on the Apple silicon GPU.
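Given orthonormal basis vectors m and n, the cross tensor and its relation to the plus tensor can be sketched as below. The helper names are hypothetical, and the sketch uses plain nested lists rather than mx.array. The check shows that e+ and e× are orthogonal under the Frobenius product and have equal norm.

```python
def outer(a, b):
    """Outer product a ⊗ b of two 3-vectors, as nested lists."""
    return [[a[i] * b[j] for j in range(3)] for i in range(3)]


def cross_tensor_from_basis(m, n):
    """Hypothetical helper: cross tensor m⊗n + n⊗m for orthonormal m, n."""
    mn, nm = outer(m, n), outer(n, m)
    return [[mn[i][j] + nm[i][j] for j in range(3)] for i in range(3)]


def plus_tensor_from_basis(m, n):
    """Hypothetical helper: plus tensor m⊗m - n⊗n, for comparison."""
    mm, nn = outer(m, m), outer(n, n)
    return [[mm[i][j] - nn[i][j] for j in range(3)] for i in range(3)]


def frobenius(a, b):
    """Frobenius inner product sum_ij a_ij * b_ij of two 3x3 matrices."""
    return sum(a[i][j] * b[i][j] for i in range(3) for j in range(3))
```

For any orthonormal pair (m, n), frobenius(e+, e×) is 0 and both tensors have squared norm 2. Together they form an orthogonal basis for the transverse-traceless tensors of a given propagation direction.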
- gwsnr.mlx.mlx_functions.antenna_response_plus(ra, dec, time, psi, detector_tensor)
Function to calculate the plus polarization antenna response for gravitational wave detection (MLX implementation).
- Parameters:
- ra : float
Right ascension of the source in radians.
- dec : float
Declination of the source in radians.
- time : float
GPS time of the source in seconds.
- psi : float
Polarization angle of the source in radians.
- detector_tensor : mx.array
Detector tensor for the detector (3x3 matrix).
- Returns:
- antenna_response_plus : float
Plus polarization antenna response of the detector.
Notes
Computes the plus polarization antenna response as the Frobenius inner product between the detector tensor D and the plus polarization tensor e+, that is, the sum over i, j of D_ij * e+_ij. The polarization tensor is determined by the source location (ra, dec), the observation time, and the polarization angle (psi). The MLX implementation supports automatic differentiation for parameter estimation workflows.
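The Frobenius inner product can be sketched as follows. The detector tensor form D = (x⊗x - y⊗y)/2, with x and y unit vectors along the two arms, is the usual choice for an L-shaped interferometer. It is an assumption here, because the module takes detector_tensor as an input. The helper names are hypothetical. The same function gives the cross response when it is passed the cross tensor.

```python
def outer(a, b):
    """Outer product a ⊗ b of two 3-vectors, as nested lists."""
    return [[a[i] * b[j] for j in range(3)] for i in range(3)]


def detector_tensor_sketch(x_arm, y_arm):
    """Assumed L-shaped detector tensor D = (x⊗x - y⊗y) / 2 (unit arm vectors)."""
    xx, yy = outer(x_arm, x_arm), outer(y_arm, y_arm)
    return [[0.5 * (xx[i][j] - yy[i][j]) for j in range(3)] for i in range(3)]


def antenna_response_sketch(detector_tensor, polarization_tensor):
    """Frobenius inner product sum_ij D_ij * e_ij (works for e+ or e×)."""
    return sum(detector_tensor[i][j] * polarization_tensor[i][j]
               for i in range(3) for j in range(3))
```

For arms along the x and y axes, passing e+ = diag(1, -1, 0) gives a response of 1 and e× gives 0. Rotating the arms by 45 degrees moves the full response into the cross channel.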
- gwsnr.mlx.mlx_functions.antenna_response_cross(ra, dec, time, psi, detector_tensor)
Function to calculate the cross polarization antenna response for gravitational wave detection (MLX implementation).
- Parameters:
- ra : float
Right ascension of the source in radians.
- dec : float
Declination of the source in radians.
- time : float
GPS time of the source in seconds.
- psi : float
Polarization angle of the source in radians.
- detector_tensor : mx.array
Detector tensor for the detector (3x3 matrix).
- Returns:
- antenna_response_cross : float
Cross polarization antenna response of the detector.
Notes
Computes the cross polarization antenna response as the Frobenius inner product between the detector tensor D and the cross polarization tensor e×, that is, the sum over i, j of D_ij * e×_ij. The polarization tensor is determined by the source location (ra, dec), the observation time, and the polarization angle (psi). The MLX implementation supports automatic differentiation for parameter estimation workflows.
- gwsnr.mlx.mlx_functions.antenna_response_array(ra, dec, time, psi, detector_tensor)
Function to calculate the antenna response for multiple detectors and sources (MLX implementation).
- Parameters:
- ra : mx.array
Array of right ascension values for sources in radians.
- dec : mx.array
Array of declination values for sources in radians.
- time : mx.array
Array of GPS times for sources in seconds.
- psi : mx.array
Array of polarization angles for sources in radians.
- detector_tensor : mx.array
Detector tensor array for multiple detectors with shape (n, 3, 3), where n is the number of detectors.
- Returns:
- Fp : mx.array
Plus polarization antenna response array with shape (n_detectors, n_sources).
- Fc : mx.array
Cross polarization antenna response array with shape (n_detectors, n_sources).
Notes
Computes the plus and cross antenna responses for multiple detectors and sources simultaneously. Uses MLX's vmap (mx.vmap) to vectorize the computation, and the result remains differentiable. Each response is the Frobenius inner product between a detector tensor and a polarization tensor derived from the source's sky location, time, and polarization angle. The computation runs on the Apple silicon GPU and supports gradient-based optimization.
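The batched computation can be sketched in plain Python, with nested loops standing in for mx.vmap. The sketch makes two assumptions: it takes one GMST value per source instead of GPS times, and it reuses the bilby-style basis convention. All names are hypothetical. The output layout matches the documented (n_detectors, n_sources) shape.

```python
import math


def _outer(a, b):
    return [[a[i] * b[j] for j in range(3)] for i in range(3)]


def _frobenius(a, b):
    return sum(a[i][j] * b[i][j] for i in range(3) for j in range(3))


def _pol_tensors(ra, dec, gmst, psi):
    """Plus and cross tensors (bilby-style basis convention, assumed)."""
    theta, phi = math.pi / 2.0 - dec, ra - gmst
    u = [math.cos(phi) * math.cos(theta), math.cos(theta) * math.sin(phi), -math.sin(theta)]
    v = [-math.sin(phi), math.cos(phi), 0.0]
    m = [-u[k] * math.sin(psi) - v[k] * math.cos(psi) for k in range(3)]
    n = [-u[k] * math.cos(psi) + v[k] * math.sin(psi) for k in range(3)]
    mm, nn, mn, nm = _outer(m, m), _outer(n, n), _outer(m, n), _outer(n, m)
    plus = [[mm[i][j] - nn[i][j] for j in range(3)] for i in range(3)]
    cross = [[mn[i][j] + nm[i][j] for j in range(3)] for i in range(3)]
    return plus, cross


def antenna_response_array_sketch(ra, dec, gmst, psi, detector_tensors):
    """Hypothetical sketch: Fp, Fc as nested lists of shape (n_detectors, n_sources)."""
    # One (plus, cross) pair per source, reused for every detector
    tensors = [_pol_tensors(*args) for args in zip(ra, dec, gmst, psi)]
    fp = [[_frobenius(d, p) for p, _ in tensors] for d in detector_tensors]
    fc = [[_frobenius(d, c) for _, c in tensors] for d in detector_tensors]
    return fp, fc
```

Two checks are useful here. First, Fp**2 + Fc**2 never exceeds 1 for a detector tensor built from unit-length orthogonal arms. Second, that sum does not depend on psi, because changing psi only rotates the pair (Fp, Fc) by 2*psi.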