# -*- coding: utf-8 -*-
"""
Numba-compiled helper functions for gravitational wave signal-to-noise ratio calculations.
This module provides optimized numerical functions for gravitational wave data analysis,
including chirp time calculations, antenna response computations, polarization tensors,
coordinate transformations, and noise-weighted inner products. All functions are compiled
with Numba's @njit decorator for high-performance computation, with parallel processing
support using prange for multi-threaded execution where applicable.
"""
# Helper functions for gwsnr. All functions are njit compiled.
import numpy as np
from numba import njit, prange
# Euler-Mascheroni constant, used in the 3PN chirp-time coefficient.
Gamma = 0.5772156649015329
# pi, referenced by the post-Newtonian chirp-time coefficients.
Pi = np.pi
# Solar mass expressed in seconds (G * M_sun / c^3).
MTSUN_SI = 4.925491025543576e-06
@njit
def findchirp_chirptime(m1, m2, fmin):
    """
    Time taken from f_min to f_lso (last stable orbit). 3.5PN in fourier phase considered.

    The chirp time is evaluated from the post-Newtonian expansion

        t(v) = c0 * (1 + c2 v^2 + c3 v^3 + c4 v^4 + c5 v^5
                     + (c6 + c6log * ln v) v^6 + c7 v^7) / v^8,

    with v = (pi * M * fmin)^(1/3), where the total mass M is converted to
    seconds via MTSUN_SI.

    Parameters
    ----------
    m1 : `float`
        Mass of the first body in solar masses.
    m2 : `float`
        Mass of the second body in solar masses.
    fmin : `float`
        Lower frequency cutoff in Hz.

    Returns
    -------
    chirp_time : `float`
        Time taken from f_min to f_lso (last stable orbit frequency), in seconds.
    """
    # total mass (solar masses) and symmetric mass ratio
    m = m1 + m2
    eta = m1 * m2 / m / m

    # Chirp-time PN coefficients. np.pi is used directly so the function does
    # not rely on a module-level ``Pi`` constant.
    c7T = np.pi * (
        14809.0 * eta * eta / 378.0 - 75703.0 * eta / 756.0 - 15419335.0 / 127008.0
    )
    # 3PN coefficient; the eta^3 term enters exactly once (it was previously
    # added twice).
    c6T = (
        Gamma * 6848.0 / 105.0
        - 10052469856691.0 / 23471078400.0
        + np.pi * np.pi * 128.0 / 3.0
        + eta * (3147553127.0 / 3048192.0 - np.pi * np.pi * 451.0 / 12.0)
        - eta * eta * 15211.0 / 1728.0
        + eta * eta * eta * 25565.0 / 1296.0
        + np.log(4.0) * 6848.0 / 105.0
    )
    # coefficient multiplying ln(v) * v^6
    c6LogT = 6848.0 / 105.0
    c5T = 13.0 * np.pi * eta / 3.0 - 7729.0 * np.pi / 252.0
    c4T = 3058673.0 / 508032.0 + eta * (5429.0 / 504.0 + eta * 617.0 / 72.0)
    c3T = -32.0 * np.pi / 5.0
    c2T = 743.0 / 252.0 + eta * 11.0 / 3.0
    # Newtonian prefactor, in seconds
    c0T = 5.0 * m * MTSUN_SI / (256.0 * eta)

    # This is the PN parameter v evaluated at the lower freq. cutoff
    xT = np.power(np.pi * m * MTSUN_SI * fmin, 1.0 / 3.0)
    x2T = xT * xT
    x3T = xT * x2T
    x4T = x2T * x2T
    x5T = x2T * x3T
    x6T = x3T * x3T
    x7T = x3T * x4T
    x8T = x4T * x4T

    # Computes the chirp time as tC = t(v_low).
    # tC = t(v_low) - t(v_upper) would be more correct,
    # but the difference is negligible.
    return (
        c0T
        * (
            1.0
            + c2T * x2T
            + c3T * x3T
            + c4T * x4T
            + c5T * x5T
            + (c6T + c6LogT * np.log(xT)) * x6T
            + c7T * x7T
        )
        / x8T
    )
@njit
def einsum1(m, n):
    """
    Outer product of two length-3 vectors, i.e. einsum('i,j->ij', m, n).

    Parameters
    ----------
    m : `numpy.ndarray`
        Vector with (at least) 3 components; only the first three are used.
    n : `numpy.ndarray`
        Vector with (at least) 3 components; only the first three are used.

    Returns
    -------
    ans : `numpy.ndarray`
        3x3 float matrix whose (row, col) entry is m[row] * n[col].
    """
    outer = np.zeros((3, 3))
    for row in range(3):
        for col in range(3):
            outer[row, col] = m[row] * n[col]
    return outer
@njit
def einsum2(m, n):
    """
    Full contraction of two 3x3 matrices, i.e. einsum('ij,ij->', m, n).

    This is the element-wise product of the two matrices summed over all nine
    entries (the Frobenius inner product).

    Parameters
    ----------
    m : `numpy.ndarray`
        3x3 matrix.
    n : `numpy.ndarray`
        3x3 matrix.

    Returns
    -------
    ans : `float`
        Scalar sum of m[i, j] * n[i, j] over i, j in 0..2.
    """
    # Accumulate in row-major order, matching the original term order.
    total = 0.0
    for row in range(3):
        for col in range(3):
            total += m[row, col] * n[row, col]
    return total
@njit
def gps_to_gmst(gps_time):
    """
    Convert GPS time to Greenwich mean sidereal time using a linear model.

    Parameters
    ----------
    gps_time : `float`
        GPS time in seconds.

    Returns
    -------
    gmst : `float`
        Greenwich mean sidereal time in radians. The value is not wrapped into
        [0, 2*pi); callers in this module reduce it with np.fmod.
    """
    # Earth's sidereal rotation rate (rad/s) and the constant offset of the
    # linear GPS -> GMST relation.
    rotation_rate = 7.292115855425873e-05
    offset = -45991.08966925838
    return rotation_rate * gps_time + offset
@njit
def ra_dec_to_theta_phi(ra, dec, gmst):
    """
    Convert equatorial sky coordinates to Earth-fixed spherical angles.

    Parameters
    ----------
    ra : `float`
        Right ascension of the source in radians.
    dec : `float`
        Declination of the source in radians.
    gmst : `float`
        Greenwich mean sidereal time in radians.

    Returns
    -------
    theta : `float`
        Polar angle in radians (pi/2 - dec).
    phi : `float`
        Azimuthal angle in radians (ra - gmst).
    """
    # co-declination: polar angle measured from the north pole
    polar_angle = np.pi / 2.0 - dec
    # right ascension rotated into the Earth-fixed frame
    azimuth = ra - gmst
    return polar_angle, azimuth
@njit
def get_polarization_tensor_plus(ra, dec, time, psi):
    """
    Plus-polarization tensor e+ = m m^T - n n^T of an incoming wave.

    Parameters
    ----------
    ra : `float`
        Right ascension of the source in radians.
    dec : `float`
        Declination of the source in radians.
    time : `float`
        GPS time of the source.
    psi : `float`
        Polarization angle of the source.

    Returns
    -------
    polarization_tensor : `numpy.ndarray`
        3x3 plus-polarization tensor (depends only on the source, not on any detector).
    """
    gmst = np.fmod(gps_to_gmst(time), 2 * np.pi)
    theta, phi = ra_dec_to_theta_phi(ra, dec, gmst)

    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    cos_phi = np.cos(phi)
    sin_phi = np.sin(phi)
    sin_psi = np.sin(psi)
    cos_psi = np.cos(psi)

    # unit vectors along increasing theta and increasing phi
    e_theta = np.array([cos_phi * cos_theta, cos_theta * sin_phi, -sin_theta])
    e_phi = np.array([-sin_phi, cos_phi, 0.0])

    # rotate the (e_theta, e_phi) basis by the polarization angle
    m = -e_theta * sin_psi - e_phi * cos_psi
    n = -e_theta * cos_psi + e_phi * sin_psi

    return einsum1(m, m) - einsum1(n, n)
@njit
def get_polarization_tensor_cross(ra, dec, time, psi):
    """
    Cross-polarization tensor ex = m n^T + n m^T of an incoming wave.

    Parameters
    ----------
    ra : `float`
        Right ascension of the source in radians.
    dec : `float`
        Declination of the source in radians.
    time : `float`
        GPS time of the source.
    psi : `float`
        Polarization angle of the source.

    Returns
    -------
    polarization_tensor : `numpy.ndarray`
        3x3 cross-polarization tensor (depends only on the source, not on any detector).
    """
    gmst = np.fmod(gps_to_gmst(time), 2 * np.pi)
    theta, phi = ra_dec_to_theta_phi(ra, dec, gmst)

    cos_theta = np.cos(theta)
    sin_theta = np.sin(theta)
    cos_phi = np.cos(phi)
    sin_phi = np.sin(phi)
    sin_psi = np.sin(psi)
    cos_psi = np.cos(psi)

    # unit vectors along increasing theta and increasing phi
    e_theta = np.array([cos_phi * cos_theta, cos_theta * sin_phi, -sin_theta])
    e_phi = np.array([-sin_phi, cos_phi, 0.0])

    # rotate the (e_theta, e_phi) basis by the polarization angle
    m = -e_theta * sin_psi - e_phi * cos_psi
    n = -e_theta * cos_psi + e_phi * sin_psi

    return einsum1(m, n) + einsum1(n, m)
@njit
def antenna_response_plus(ra, dec, time, psi, detector_tensor):
    """
    Plus antenna pattern F+ = sum_ij D_ij e+_ij for a single detector.

    Parameters
    ----------
    ra : `float`
        Right ascension of the source in radians.
    dec : `float`
        Declination of the source in radians.
    time : `float`
        GPS time of the source.
    psi : `float`
        Polarization angle of the source.
    detector_tensor : array-like
        Detector tensor for the detector (3x3 matrix).

    Returns
    -------
    antenna_response : `float`
        Plus-polarization antenna response of the detector.
    """
    return einsum2(
        detector_tensor,
        get_polarization_tensor_plus(ra, dec, time, psi),
    )
@njit
def antenna_response_cross(ra, dec, time, psi, detector_tensor):
    """
    Cross antenna pattern Fx = sum_ij D_ij ex_ij for a single detector.

    Parameters
    ----------
    ra : `float`
        Right ascension of the source in radians.
    dec : `float`
        Declination of the source in radians.
    time : `float`
        GPS time of the source.
    psi : `float`
        Polarization angle of the source.
    detector_tensor : array-like
        Detector tensor for the detector (3x3 matrix).

    Returns
    -------
    antenna_response : `float`
        Cross-polarization antenna response of the detector.
    """
    return einsum2(
        detector_tensor,
        get_polarization_tensor_cross(ra, dec, time, psi),
    )
@njit(parallel=True)
def antenna_response_array(ra, dec, time, psi, detector_tensor):
    """
    Function to calculate the antenna response in array form.

    Parameters
    ----------
    ra : `numpy.ndarray`
        Right ascension of the source in radians.
    dec : `numpy.ndarray`
        Declination of the source in radians.
    time : `numpy.ndarray`
        GPS time of the source.
    psi : `numpy.ndarray`
        Polarization angle of the source.
    detector_tensor : array-like
        Detector tensor for the multiple detectors (nx3x3 matrix), where n is the number of detectors.

    Returns
    -------
    Fp : `numpy.ndarray`
        Plus-polarization antenna response. Shape is (n, len(ra)).
    Fc : `numpy.ndarray`
        Cross-polarization antenna response. Shape is (n, len(ra)).
    """
    len_det = len(detector_tensor)
    len_param = len(ra)
    Fp = np.zeros((len_det, len_param))
    Fc = np.zeros((len_det, len_param))
    for i in prange(len_param):
        # The polarization tensors depend only on the source parameters, not
        # on the detector, so build them once per source and contract them
        # with every detector tensor (same arithmetic as antenna_response_*).
        pol_plus = get_polarization_tensor_plus(ra[i], dec[i], time[i], psi[i])
        pol_cross = get_polarization_tensor_cross(ra[i], dec[i], time[i], psi[i])
        for j in range(len_det):
            Fp[j, i] = einsum2(detector_tensor[j], pol_plus)
            Fc[j, i] = einsum2(detector_tensor[j], pol_cross)
    return Fp, Fc
@njit
def effective_distance(
    luminosity_distance, theta_jn, ra, dec, geocent_time, psi, detector_tensor
):
    """
    Effective distance of a source as seen by a single detector.

    D_eff = D_L / sqrt(F+^2 * ((1 + cos^2 theta_jn) / 2)^2 + Fx^2 * cos^2 theta_jn)

    Parameters
    ----------
    luminosity_distance : `float`
        Luminosity distance of the source in Mpc.
    theta_jn : `float`
        Angle between the line of sight and the orbital angular momentum vector.
    ra : `float`
        Right ascension of the source in radians.
    dec : `float`
        Declination of the source in radians.
    geocent_time : `float`
        GPS time of the source.
    psi : `float`
        Polarization angle of the source.
    detector_tensor : array-like
        Detector tensor for the detector (3x3 matrix).

    Returns
    -------
    effective_distance : `float`
        Effective distance of the source, in the same units as luminosity_distance.
    """
    f_plus = antenna_response_plus(ra, dec, geocent_time, psi, detector_tensor)
    f_cross = antenna_response_cross(ra, dec, geocent_time, psi, detector_tensor)

    cos_sq = np.cos(theta_jn) ** 2
    plus_weight = ((1 + cos_sq) / 2) ** 2
    denominator = np.sqrt(f_plus**2 * plus_weight + f_cross**2 * cos_sq)
    return luminosity_distance / denominator
@njit(parallel=True)
def effective_distance_array(
    luminosity_distance, theta_jn, ra, dec, geocent_time, psi, detector_tensor
):
    """
    Effective distance for many sources and several detectors.

    Parameters
    ----------
    luminosity_distance : `numpy.ndarray`
        Luminosity distance of the source in Mpc.
    theta_jn : `numpy.ndarray`
        Angle between the line of sight and the orbital angular momentum vector.
    ra : `numpy.ndarray`
        Right ascension of the source in radians.
    dec : `numpy.ndarray`
        Declination of the source in radians.
    geocent_time : `numpy.ndarray`
        GPS time of the source.
    psi : `numpy.ndarray`
        Polarization angle of the source.
    detector_tensor : array-like
        Detector tensor for the multiple detectors (nx3x3 matrix), where n is the number of detectors.

    Returns
    -------
    effective_distance : `numpy.ndarray`
        Effective distance of each source in each detector. Shape is (n, len(ra)).
    """
    n_detectors = len(detector_tensor)
    n_sources = len(ra)
    result = np.zeros((n_detectors, n_sources))
    # sources are processed in parallel; each writes only its own column
    for src in prange(n_sources):
        d_lum = luminosity_distance[src]
        incl = theta_jn[src]
        alpha = ra[src]
        delta = dec[src]
        t_gps = geocent_time[src]
        pol = psi[src]
        for det in range(n_detectors):
            result[det, src] = effective_distance(
                d_lum, incl, alpha, delta, t_gps, pol, detector_tensor[det]
            )
    return result
@njit
def noise_weighted_inner_product(
    signal1, signal2, psd, duration,
):
    """
    Noise-weighted inner product (4 / duration) * sum(conj(signal1) * signal2 / psd).

    The signals are divided by the detector PSD sample by sample, so they are
    expected to be frequency-domain series sampled on the same grid as psd.

    Parameters
    ----------
    signal1 : `numpy.ndarray`
        First frequency-domain data set (conjugated).
    signal2 : `numpy.ndarray`
        Second frequency-domain data set.
    psd : `numpy.ndarray`
        Power spectral density of the detector.
    duration : `float`
        Duration of the data; 1 / duration acts as the frequency step.

    Returns
    -------
    inner_product : `complex`
        The (generally complex) sum; the real part is NOT taken here.
    """
    weighted = np.conj(signal1) * signal2 / psd
    prefactor = 4 / duration
    return prefactor * np.sum(weighted)
@njit(parallel=True)
def linear_interpolator(xnew_array, y_array, x_array, fill_value=np.inf):
    """
    Linear interpolator for 1D data.

    Points inside the closed interval [x_array[0], x_array[-1]] are linearly
    interpolated, including the first interval [x_array[0], x_array[1]].
    Points outside that range (and NaN inputs) receive ``fill_value``.

    Parameters
    ----------
    xnew_array : `numpy.ndarray`
        New x values to interpolate.
    y_array : `numpy.ndarray`
        y values corresponding to the x_array.
    x_array : `numpy.ndarray`
        Original x values, assumed sorted in ascending order.
    fill_value : `float`
        Value assigned to points outside the range of x_array. Default is np.inf.

    Returns
    -------
    result : `numpy.ndarray`
        Interpolated y values at xnew_array (same shape and dtype as xnew_array).
    """
    result = np.zeros_like(xnew_array)
    n_x = len(x_array)
    len_ = xnew_array.shape[0]
    for j in prange(len_):
        xnew = xnew_array[j]
        # left edge of the bracketing interval [x[i], x[i+1]]
        i = np.searchsorted(x_array, xnew) - 1
        # searchsorted (side='left') gives 0 for xnew == x_array[0]; treat the
        # left boundary as belonging to the first interval, not as extrapolation.
        if i == -1 and n_x > 1 and xnew == x_array[0]:
            i = 0
        # Previously ``i > 0`` excluded the whole first interval.
        if i >= 0 and i < n_x - 1:
            x0, x1 = x_array[i], x_array[i + 1]
            y0, y1 = y_array[i], y_array[i + 1]
            result[j] = y0 + (y1 - y0) * (xnew - x0) / (x1 - x0)
        else:
            # outside the tabulated range: no extrapolation
            result[j] = fill_value
    return result
# NOTE: disabled helper, kept for reference only.
# @njit
# def _helper_hphc(hp,hc,fsize_arr,fs,size,f_l,i):
#     # remove the np.nan padding
#     hp_ = np.array(hp[i][:fsize_arr[i]], dtype=np.complex128)
#     hc_ = np.array(hc[i][:fsize_arr[i]], dtype=np.complex128)
#     # find the index of the frequency closest to f_l (e.g. 20 Hz)
#     # and set all elements below this index to zero
#     idx = np.abs(fs[i] - f_l).argmin()
#     # hp_/hc_ already hold row i, so slice them directly
#     hp_[0:idx] = 0.0 + 0.0j
#     hc_[0:idx] = 0.0 + 0.0j
#     return hp_,hc_