{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Optimal Signal to Noise Ratio (SNR) generation and comparison \n", "\n", "This notebook is a full guide on how to use the `gwsnr` package to generate the 'optimal-SNR' ($\\rho_{\\rm opt}$) and 'probability of detection' (Pdet) for a given Gravitational Wave (GW) signal.\n", "\n", "## Contents of this notebook\n", "\n", "$\\rho_{\\rm opt}$ calculation with,\n", " - Noise-Weighted Inner Product \n", " - Partial Scaling Interpolation\n", " - ANN-model and $P_{\\rm det}$ Estimation \n", " - Hybrid SNR Recalculation for $P_{\\rm det}$ Estimation\n", " - JAX assisted Inner product with `ripplegw` as backend\n", "\n", "Note: For more details on SNR calculation methods, please refer to the [gwsnr documentation](https://gwsnr.hemantaph.com).\n", "\n", "\n", "## Requirements\n", "\n", "- `gwsnr` for all SNR calculations\n", "- `jax` and `jaxlib` for Partial Scaling Interpolation, if you want to use the jax version. `\"jax[cuda12]\"` for running on Nvidia GPU.\n", "- `mlx` for Partial Scaling Interpolation, if you want to use Apple Silicon M-series GPU.\n", "- `ripplegw` for JAX assisted Inner product method.\n", "- `scikit-learn` and `tensorflow` for ANN model based SNR calculation. `ml-dtypes` might need to be updated to the latest version for compatibility with TensorFlow.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Below shows GWSNR initialization with all of its (default) arguments\n", "\n", "```python\n", "gwsnr = gwsnr.GWSNR(\n", " #################################\n", " # General settings\n", " npool=4, # Number of processors for parallel processing. Run this to check the number of cores in your machine; import os; os.cpu_count()\n", " snr_method='interpolation_no_spins', # SNR calculation method. 
Other options: 'interpolation_no_spins_jax', 'interpolation_no_spins_mlx', 'interpolation_aligned_spins', 'interpolation_aligned_spins_jax', 'interpolation_aligned_spins_mlx', 'inner_product', 'inner_product_jax', 'ann' \n", " snr_type='optimal_snr', # Type of SNR to be calculated. 'matched_filter_snr' option will be available in future releases.\n", " gwsnr_verbose=True, # If True, it will print all gwsnr settings\n", " multiprocessing_verbose=True, # If True, it will show progress bar for multiprocessing. \n", " pdet_kwargs=dict(snr_th=10.0, snr_th_net=10.0, pdet_type='boolean', distribution_type='noncentral_chi2'), # Dictionary of Pdet settings\n", " #################################\n", " # Settings for interpolation grid\n", " mtot_min=2*4.98, # Minimum total mass (Mo) for interpolation grid. 4.98 Mo is the minimum component mass of BBH systems in GWTC-3\n", " mtot_max=2*112.5+10.0, # Maximum total mass (Mo) for interpolation grid. 112.5 Mo is the maximum component mass of BBH systems in GWTC-3. 10.0 Mo is added to avoid edge effects.\n", " ratio_min=0.1, # Minimum mass ratio for interpolation grid\n", " ratio_max=1.0, # Maximum mass ratio for interpolation grid\n", " spin_max=0.99, # Maximum spin magnitude for interpolation grid. 
Note: spin_min= -spin_max, as in aligned spin systems.\n", " mtot_resolution=200, # Number of points in total mass axis for interpolation grid\n", " ratio_resolution=20, # Number of points in mass ratio axis for interpolation grid\n", " spin_resolution=10, # Number of points in spin magnitude axis for interpolation grid.\n", " batch_size_interpolation=1000000, # Number of samples to be processed in each batch for interpolation method\n", " interpolator_dir='./interpolator_pickle', # Directory to save/load the interpolator\n", " create_new_interpolator=False, # If True, it will overwrite the existing interpolator\n", " #################################\n", " # GW signal settings\n", " sampling_frequency=2048.0, # Sampling frequency in Hz\n", " waveform_approximant='IMRPhenomD', # Frequency domain waveform approximant of the GW signal\n", " frequency_domain_source_model='lal_binary_black_hole', # Source model for frequency domain waveform generation. \n", " minimum_frequency=20.0, # Minimum frequency of the waveform in Hz\n", " reference_frequency=20.0, # Reference frequency for spin\n", " duration_max=None, # Maximum duration of the waveform in seconds. Bilby default for IMRPhenomXPHM is 64 seconds.\n", " duration_min=None, # Minimum duration of the waveform in seconds. Bilby default is 4 seconds.\n", " fixed_duration=None, # If a float value is provided, all waveforms will be generated with this fixed duration (in seconds).\n", " mtot_cut=False, # If True, SNR=0 for total mass associated with signal duration < 1.1*chirp_duration\n", " #################################\n", " # Detector settings\n", " psds= {'L1':'aLIGO_O4_high_asd.txt','H1':'aLIGO_O4_high_asd.txt', 'V1':'AdV_asd.txt', 'K1':'KAGRA_design_asd.txt'}, # Power spectral density of the detectors. Other options: psd names from pycbc (e.g. 'aLIGODesign'), psd via gps time (e.g. 1234567890), or custom PSD as txt file.\n", " ifos=['L1', 'H1', 'V1'], # List of detectors. 
You can also provide bilby interferometer objects.\n", " #################################\n", " # ANN settings\n", " ann_path_dict=None, # Path to the ANN model for SNR and Pdet calculation\n", " #################################\n", " # Hybrid SNR recalculation settings\n", " snr_recalculation=False, # If True, enables hybrid SNR recalculation for systems near detection threshold. Default is False.\n", " snr_recalculation_range=[6,14], # SNR range [min, max] for triggering recalculation with inner product method.\n", " snr_recalculation_waveform_approximant='IMRPhenomXPHM', # Waveform approximant to use for SNR recalculation. Default is 'IMRPhenomXPHM'.\n", ")\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Noise-Weighted Inner Product " ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWSNR class...\n", "\n", "Copying interpolator data from the library resource /Users/phurailatpamhemantakumar/anaconda3/envs/gwsnr3/lib/python3.11/site-packages/gwsnr/core/interpolator_pickle to the current working directory.\n", "psds not given. Choosing bilby's default psds\n", "Intel processor has trouble allocating memory when the data is huge. So, by default for IMRPhenomXPHM, duration_max = 64.0. 
Otherwise, set to some max value like duration_max = 600.0 (10 mins)\n", "\n", "Chosen GWSNR initialization parameters:\n", "\n", "npool: 4\n", "snr type: inner_product\n", "waveform approximant: IMRPhenomXPHM\n", "sampling frequency: 2048.0\n", "minimum frequency (fmin): 20.0\n", "reference frequency (f_ref): 20.0\n", "mtot=mass1+mass2\n", "min(mtot): 9.96\n", "max(mtot) (with the given fmin=20.0): 235.0\n", "detectors: ['L1', 'H1', 'V1']\n", "psds: [[array([ 10.21659, 10.23975, 10.26296, ..., 4972.81 ,\n", " 4984.081 , 4995.378 ], shape=(2736,)), array([4.43925574e-41, 4.22777986e-41, 4.02102594e-41, ...,\n", " 6.51153524e-46, 6.43165104e-46, 6.55252996e-46],\n", " shape=(2736,)), ], [array([ 10.21659, 10.23975, 10.26296, ..., 4972.81 ,\n", " 4984.081 , 4995.378 ], shape=(2736,)), array([4.43925574e-41, 4.22777986e-41, 4.02102594e-41, ...,\n", " 6.51153524e-46, 6.43165104e-46, 6.55252996e-46],\n", " shape=(2736,)), ], [array([ 10. , 10.02306 , 10.046173, ...,\n", " 9954.0389 , 9976.993 , 10000. 
], shape=(3000,)), array([1.22674387e-42, 1.20400299e-42, 1.18169466e-42, ...,\n", " 1.51304203e-43, 1.52010157e-43, 1.52719372e-43],\n", " shape=(3000,)), ]]\n", "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 3.30it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Computed SNRs with inner product:\n", " {'L1': array([46.53022776]), 'H1': array([48.20334347]), 'V1': array([13.20663045]), 'snr_net': array([68.28645184])}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# loading GWSNR class from the gwsnr package\n", "import gwsnr\n", "import numpy as np\n", "\n", "# initializing the GWSNR class with inner product as the signal-to-noise ratio type\n", "gwsnr = gwsnr.GWSNR(\n", " snr_method='inner_product', \n", " waveform_approximant='IMRPhenomXPHM')\n", "\n", "# signal-to-noise ratio for a BBH with GW150914 like parameters with detectors LIGO-Hanford, LIGO-Livingston, and Virgo with O4 observing run sensitivity\n", "snrs = gwsnr.optimal_snr(\n", " mass_1=np.array([36.0]), # mass of the primary black hole in solar masses\n", " mass_2=np.array([29.0]), # mass of the secondary black hole in solar masses\n", " luminosity_distance=np.array([440.0]), # luminosity distance in Mpc\n", " theta_jn=np.array([1.0]), # inclination angle in radians\n", " ra=np.array([3.435]), # right ascension in radians\n", " dec=np.array([-0.408]), # declination in radians\n", " psi=np.array([0.0]), # polarization angle in radians\n", " geocent_time=np.array([1126259462.4]), # geocentric time in GPS seconds\n", " a_1=np.array([0.3]), # dimensionless spin of the primary black hole\n", " a_2=np.array([0.2]), # dimensionless spin of the secondary black hole\n", " tilt_1=np.array([0.5]), # tilt angle of the primary black hole in radians\n", " tilt_2=np.array([0.8]), # tilt angle of the secondary black hole in radians\n", " 
phi_12=np.array([0.0]), # Relative angle between the primary and secondary spin of the binary in radians\n", " phi_jl=np.array([0.0]), # Angle between the total angular momentum and the orbital angular momentum in radians\n", ")\n", "\n", "print('\\nComputed SNRs with inner product:\\n', snrs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Partial Scaling Interpolation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWSNR class...\n", "\n", "psds not given. Choosing bilby's default psds\n", "Interpolator will be generated for L1 detector at ./interpolator_pickle/L1/partialSNR_dict_4.pickle\n", "Interpolator will be generated for H1 detector at ./interpolator_pickle/H1/partialSNR_dict_4.pickle\n", "Interpolator will be generated for V1 detector at ./interpolator_pickle/V1/partialSNR_dict_4.pickle\n", "Please be patient while the interpolator is generated\n", "Generating interpolator for ['L1', 'H1', 'V1'] detectors\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████████████████████| 4000/4000 [00:01<00:00, 2797.16it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Saving Partial-SNR for L1 detector with shape (20, 200)\n", "\n", "Saving Partial-SNR for H1 detector with shape (20, 200)\n", "\n", "Saving Partial-SNR for V1 detector with shape (20, 200)\n", "\n", "\n", "Interpolation results: \n", "{'L1': array([ 7.40707612, 11.5964448 , 31.03520024, 26.59768336]), 'H1': array([ 4.7108495 , 7.37525919, 19.73817404, 16.91594381]), 'V1': array([2.22222811, 3.44022184, 9.33438517, 7.86082025]), 'snr_net': array([ 9.05511885, 14.16711355, 37.94614493, 32.48658816])}\n", "\n", " Inner product results: \n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 14.93it/s]" ] 
}, { "name": "stdout", "output_type": "stream", "text": [ "{'L1': array([ 7.40686229, 11.5956576 , 31.03469553, 26.59785488]), 'H1': array([ 4.71071351, 7.37475854, 19.73785306, 16.91605289]), 'V1': array([2.22218338, 3.44021659, 9.33421055, 7.8608819 ]), 'snr_net': array([ 9.05486221, 14.16620729, 37.94552222, 32.4868003 ])}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "import numpy as np\n", "import gwsnr\n", "\n", "# initializing the GWSNR class with default configuration and interpolation method\n", "# for non-spinning IMRPhenomD waveform \n", "gwsnr_no_spins = gwsnr.GWSNR(\n", " npool=4,\n", " snr_method='interpolation_no_spins', # Other options: 'interpolation_no_spins_jax', 'interpolation_no_spins_mlx', 'interpolation_aligned_spins', 'interpolation_aligned_spins_jax', 'interpolation_aligned_spins_mlx'\n", " gwsnr_verbose=False,\n", " waveform_approximant='IMRPhenomD',\n", ")\n", "\n", "# Quick test and comparison between interpolation and inner product methods\n", "mass_1 = np.array([5, 10.,50.,100.])\n", "ratio = np.array([1, 0.8,0.5,0.2])\n", "dl = 1000\n", "print('Interpolation results: ')\n", "print(gwsnr_no_spins.optimal_snr(mass_1=mass_1, mass_2=mass_1*ratio, luminosity_distance=dl))\n", "print('\\n Inner product results: ')\n", "print(gwsnr_no_spins.optimal_snr_with_inner_product(mass_1=mass_1, mass_2=mass_1*ratio, luminosity_distance=dl))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Comparing speed and accuracy of the interpolation and inner product methods\n", "\n", "* Set up the binary black hole (BBH) parameters\n", "* Simulate 10000 events\n", "* Use the IMRPhenomD waveform with the aligned-spin interpolation method" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWSNR class...\n", "\n", "Copying interpolator data from the library resource 
/Users/phurailatpamhemantakumar/anaconda3/envs/gwsnr2/lib/python3.10/site-packages/gwsnr/core/interpolator_pickle to the current working directory.\n", "psds not given. Choosing bilby's default psds\n", "Interpolator will be loaded for L1 detector from ./interpolator_pickle/L1/partialSNR_dict_0.pickle\n", "Interpolator will be loaded for H1 detector from ./interpolator_pickle/H1/partialSNR_dict_0.pickle\n", "Interpolator will be loaded for V1 detector from ./interpolator_pickle/V1/partialSNR_dict_0.pickle\n", "\n", "Chosen GWSNR initialization parameters:\n", "\n", "npool: 4\n", "snr type: interpolation_aligned_spins\n", "waveform approximant: IMRPhenomD\n", "sampling frequency: 2048.0\n", "minimum frequency (fmin): 20.0\n", "reference frequency (f_ref): 20.0\n", "mtot=mass1+mass2\n", "min(mtot): 9.96\n", "max(mtot) (with the given fmin=20.0): 235.0\n", "detectors: ['L1', 'H1', 'V1']\n", "psds: [[array([ 10.21659, 10.23975, 10.26296, ..., 4972.81 ,\n", " 4984.081 , 4995.378 ], shape=(2736,)), array([4.43925574e-41, 4.22777986e-41, 4.02102594e-41, ...,\n", " 6.51153524e-46, 6.43165104e-46, 6.55252996e-46],\n", " shape=(2736,)), ], [array([ 10.21659, 10.23975, 10.26296, ..., 4972.81 ,\n", " 4984.081 , 4995.378 ], shape=(2736,)), array([4.43925574e-41, 4.22777986e-41, 4.02102594e-41, ...,\n", " 6.51153524e-46, 6.43165104e-46, 6.55252996e-46],\n", " shape=(2736,)), ], [array([ 10. , 10.02306 , 10.046173, ...,\n", " 9954.0389 , 9976.993 , 10000. 
], shape=(3000,)), array([1.22674387e-42, 1.20400299e-42, 1.18169466e-42, ...,\n", " 1.51304203e-43, 1.52010157e-43, 1.52719372e-43],\n", " shape=(3000,)), ]]\n", "\n", "\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from datetime import datetime\n", "import gwsnr\n", "\n", "gwsnr_aligned_spins = gwsnr.GWSNR(\n", " npool=4,\n", " snr_method='interpolation_aligned_spins', \n", " gwsnr_verbose=True,\n", " # waveform_approximant='IMRPhenomD', # default waveform\n", ")" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# general case, random parameters\n", "# set the random seed for reproducibility\n", "np.random.seed(42)\n", "nsamples = 10000\n", "mtot = np.random.uniform(2*4.98, 2*112.5,nsamples)\n", "mass_ratio = np.random.uniform(0.2,1,size=nsamples)\n", "param_dict = dict(\n", " # convert to component masses\n", " mass_1 = mtot / (1 + mass_ratio),\n", " mass_2 = mtot * mass_ratio / (1 + mass_ratio),\n", " # Fix luminosity distance\n", " luminosity_distance = 500*np.ones(nsamples),\n", " # Randomly sample everything else:\n", " theta_jn = np.random.uniform(0,2*np.pi, size=nsamples),\n", " ra = np.random.uniform(0,2*np.pi, size=nsamples), \n", " dec = np.random.uniform(-np.pi/2,np.pi/2, size=nsamples), \n", " psi = np.random.uniform(0,2*np.pi, size=nsamples),\n", " phase = np.random.uniform(0,2*np.pi, size=nsamples),\n", " geocent_time = 1246527224.169434*np.ones(nsamples),\n", ")" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Duration interpolation: 0:00:00.035357\n" ] } ], "source": [ "# aligned-spin IMRPhenomD interpolation (no spin values are passed in param_dict here)\n", "# the first run takes longer because the JIT-compiled code has to be generated\n", "start_time = datetime.now()\n", "interp_snr_aligned_spins = gwsnr_aligned_spins.optimal_snr(gw_param_dict=param_dict)\n", "end_time = datetime.now()\n", "print('Duration 
interpolation: {}'.format(end_time - start_time))" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████████████████████████████████████████████████| 10000/10000 [00:03<00:00, 2813.75it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Duration inner-product: 0:00:03.628216\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# error might occur for mchirp>95., if f_min=20. \n", "start_time = datetime.now()\n", "bilby_snr = gwsnr_aligned_spins.optimal_snr_with_inner_product(gw_param_dict=param_dict)\n", "end_time = datetime.now()\n", "print('Duration inner-product: {}'.format(end_time - start_time))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* interpolation method is much faster than inner product method (multiprocessing)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Accuracy check plot" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# compare the SNRs\n", "plt.figure(figsize=(8,4))\n", "\n", "# Left panel: rho_pred vs rho_true\n", "plt.subplot(1, 2, 1)\n", "plt.plot(bilby_snr['snr_net'], interp_snr_aligned_spins['snr_net'], 'o', color='C0')\n", "# diagonal line\n", "max_ = 100\n", "min_ = 0\n", "plt.plot([min_, max_], [min_, max_], 'r--', label='y=x')\n", "plt.xlabel(r\"$\\rho_{\\rm net, opt, pred}$\")\n", "plt.ylabel(r\"$\\rho_{\\rm net, opt, true}$\")\n", "# plt.xscale('log')\n", "# plt.yscale('log')\n", "plt.xlim(min_, max_)\n", "plt.ylim(min_, max_)\n", "plt.legend()\n", "plt.title(r\"Predicted vs. True SNR\")\n", "plt.grid(alpha=0.4)\n", "\n", "# Right panel: SNR difference\n", "plt.subplot(1, 2, 2)\n", "idx = (bilby_snr['snr_net'] > 4) & (bilby_snr['snr_net'] < 12)\n", "plt.plot(bilby_snr['snr_net'][idx], abs(interp_snr_aligned_spins['snr_net'][idx]-bilby_snr['snr_net'][idx]), 'o', color='C0')\n", "plt.xlabel(r\"$\\rho_{\\rm net, opt, true}$\")\n", "plt.ylabel(r\"$|\\rho_{\\rm net, opt, true} - \\rho_{\\rm net, opt, interp}|$\")\n", "plt.title(r\"SNR difference\")\n", "plt.grid(alpha=0.4)\n", "plt.title(r\"SNR difference for $4 < \\rho_{\\rm net, opt, true} < 12$\")\n", "plt.tight_layout() # Adjust layout to prevent overlapping titles/labels\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Maximum percentage difference between interpolated and bilby SNRs: 0.08%\n", "Maximum absolute difference between interpolated and bilby SNRs: 0.1853\n" ] } ], "source": [ "# percentage difference\n", "percent_diff = 100 * abs(interp_snr_aligned_spins['snr_net'] - bilby_snr['snr_net']) / bilby_snr['snr_net']\n", "print(f\"Maximum percentage difference between interpolated and bilby SNRs: {np.max(percent_diff):.2f}%\")\n", "# absolute difference\n", "print(f\"Maximum absolute difference between interpolated and bilby SNRs: 
{np.max(abs(interp_snr_aligned_spins['snr_net'] - bilby_snr['snr_net'])):.4f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note: Some outliers with larger SNR differences can occur; they correspond to edge cases with high total mass and/or low mass ratio. Events with negative chirp times, which likely merge outside (below) the frequency band of interest, can also appear among the outliers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ANN-model and $P_{\\rm det}$ Estimation \n", "\n", "* For more details on the $P_{\\rm det}$ calculation, please refer to the [Pdet Description](https://gwsnr.hemantaph.com/probabilityofdetection.html) page." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWSNR class...\n", "\n", "psds not given. Choosing bilby's default psds\n", "Intel processor has trouble allocating memory when the data is huge. So, by default for IMRPhenomXPHM, duration_max = 64.0. Otherwise, set to some max value like duration_max = 600.0 (10 mins)\n", "ANN model and scaler path is not given. 
Using the default path.\n", "ANN model for L1 is loaded from ./ann_data/ann_model_L1.h5.\n", "ANN scaler for L1 is loaded from ./ann_data/scaler_L1.pkl.\n", "ANN error_adjustment for L1 is loaded from ./ann_data/error_adjustment_L1.json.\n", "ANN model for H1 is loaded from ./ann_data/ann_model_H1.h5.\n", "ANN scaler for H1 is loaded from ./ann_data/scaler_H1.pkl.\n", "ANN error_adjustment for H1 is loaded from ./ann_data/error_adjustment_H1.json.\n", "ANN model for V1 is loaded from ./ann_data/ann_model_V1.h5.\n", "ANN scaler for V1 is loaded from ./ann_data/scaler_V1.pkl.\n", "ANN error_adjustment for V1 is loaded from ./ann_data/error_adjustment_V1.json.\n", "Interpolator will be loaded for L1 detector from ./interpolator_pickle/L1/partialSNR_dict_1.pickle\n", "Interpolator will be loaded for H1 detector from ./interpolator_pickle/H1/partialSNR_dict_1.pickle\n", "Interpolator will be loaded for V1 detector from ./interpolator_pickle/V1/partialSNR_dict_1.pickle\n", "\n", "\n", "\n", "Initializing GWSNR class...\n", "\n", "psds not given. Choosing bilby's default psds\n", "Intel processor has trouble allocating memory when the data is huge. So, by default for IMRPhenomXPHM, duration_max = 64.0. Otherwise, set to some max value like duration_max = 600.0 (10 mins)\n", "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/phurailatpamhemantakumar/anaconda3/envs/gwsnr2/lib/python3.10/site-packages/sklearn/base.py:442: InconsistentVersionWarning: Trying to unpickle estimator StandardScaler from version 1.7.0 when using version 1.7.2. This might lead to breaking code or invalid results. Use at your own risk. 
For more info please refer to:\n", "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n", " warnings.warn(\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import gwsnr\n", "\n", "# it will take a while to generate interpolation data for precessing systems.\n", "# interpolation is used for the partial-SNR computation.\n", "# partial-SNR is one of the inputs to compute the full SNR using ANN models.\n", "gwsnr_ann = gwsnr.GWSNR(\n", " npool=4,\n", " snr_method='ann', \n", " waveform_approximant='IMRPhenomXPHM', \n", " gwsnr_verbose=False,\n", " # ann_path_dict='./ann_data/ann_path_dict.json', # ann_path_dict can be used if you generate your own ANN models\n", ")\n", "\n", "gwsnr_bilby = gwsnr.GWSNR(\n", " npool=4,\n", " snr_method='inner_product', \n", " waveform_approximant='IMRPhenomXPHM', \n", " gwsnr_verbose=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████| 10000/10000 [00:14<00:00, 685.96it/s]\n" ] } ], "source": [ "# general case, random parameters\n", "# set the random seed for reproducibility\n", "np.random.seed(42)\n", "nsamples = 10000\n", "mtot = np.random.uniform(2*4.98, 2*112.5,nsamples)\n", "mass_ratio = np.random.uniform(0.2,1,size=nsamples)\n", "param_dict = dict(\n", " # convert to component masses\n", " mass_1 = mtot / (1 + mass_ratio),\n", " mass_2 = mtot * mass_ratio / (1 + mass_ratio),\n", " # Randomly sample the luminosity distance\n", " luminosity_distance = np.random.uniform(100, 10000, size=nsamples), # Random luminosity distance between 100 and 10000 Mpc\n", " # Randomly sample everything else:\n", " theta_jn = np.random.uniform(0,2*np.pi, size=nsamples),\n", " ra = np.random.uniform(0,2*np.pi, size=nsamples), \n", " dec = np.random.uniform(-np.pi/2,np.pi/2, size=nsamples), \n", " psi = np.random.uniform(0,2*np.pi, 
size=nsamples),\n", " phase = np.random.uniform(0,2*np.pi, size=nsamples),\n", " geocent_time = 1246527224.169434*np.ones(nsamples),\n", " # random spin magnitudes and orientations (precessing spins)\n", " a_1 = np.random.uniform(0.0,0.8, size=nsamples),\n", " a_2 = np.random.uniform(0.0,0.8, size=nsamples),\n", " tilt_1 = np.random.uniform(0, np.pi, size=nsamples), # tilt angle of the primary black hole in radians\n", " tilt_2 = np.random.uniform(0, np.pi, size=nsamples),\n", " phi_12 = np.random.uniform(0, 2*np.pi, size=nsamples), # Relative angle between the primary and secondary spin of the binary in radians\n", " phi_jl = np.random.uniform(0, 2*np.pi, size=nsamples), # Angle between the total angular momentum and the orbital angular momentum in radians\n", ")\n", "\n", "# ANN model method\n", "ann_pdet = gwsnr_ann.pdet(gw_param_dict=param_dict.copy(), pdet_type='boolean', distribution_type='fixed_snr')\n", "\n", "# inner product method\n", "bilby_pdet = gwsnr_bilby.pdet(gw_param_dict=param_dict.copy(), pdet_type='boolean', distribution_type='fixed_snr')" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Confusion Matrix:\n", "[[6151 209]\n", " [ 126 3514]]\n", "Accuracy: 96.65 %\n" ] } ], "source": [ "# accuracy of the probability of detection\n", "from sklearn.metrics import confusion_matrix, accuracy_score\n", "\n", "pdet_pred = ann_pdet['pdet_net']\n", "pdet_true = bilby_pdet['pdet_net']\n", "\n", "cm = confusion_matrix(pdet_true, pdet_pred)\n", "print(\"Confusion Matrix:\")\n", "print(cm)\n", "\n", "acc = accuracy_score(pdet_true, pdet_pred)\n", "print(f\"Accuracy: {acc*100:.2f} %\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hybrid SNR Recalculation for $P_{\\rm det}$ Estimation" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWSNR class...\n", "\n", "psds not given. 
Choosing bilby's default psds\n", "Intel processor has trouble allocating memory when the data is huge. So, by default for IMRPhenomXPHM, duration_max = 64.0. Otherwise, set to some max value like duration_max = 600.0 (10 mins)\n", "ANN model and scaler path is not given. Using the default path.\n", "ANN model for L1 is loaded from ./ann_data/ann_model_L1.h5.\n", "ANN scaler for L1 is loaded from ./ann_data/scaler_L1.pkl.\n", "ANN error_adjustment for L1 is loaded from ./ann_data/error_adjustment_L1.json.\n", "ANN model for H1 is loaded from ./ann_data/ann_model_H1.h5.\n", "ANN scaler for H1 is loaded from ./ann_data/scaler_H1.pkl.\n", "ANN error_adjustment for H1 is loaded from ./ann_data/error_adjustment_H1.json.\n", "ANN model for V1 is loaded from ./ann_data/ann_model_V1.h5.\n", "ANN scaler for V1 is loaded from ./ann_data/scaler_V1.pkl.\n", "ANN error_adjustment for V1 is loaded from ./ann_data/error_adjustment_V1.json.\n", "Interpolator will be loaded for L1 detector from ./interpolator_pickle/L1/partialSNR_dict_5.pickle\n", "Interpolator will be loaded for H1 detector from ./interpolator_pickle/H1/partialSNR_dict_5.pickle\n", "Interpolator will be loaded for V1 detector from ./interpolator_pickle/V1/partialSNR_dict_5.pickle\n", "\n", "\n", "\n", "Initializing GWSNR class...\n", "\n", "psds not given. Choosing bilby's default psds\n", "Intel processor has trouble allocating memory when the data is huge. So, by default for IMRPhenomXPHM, duration_max = 64.0. Otherwise, set to some max value like duration_max = 600.0 (10 mins)\n", "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/Users/phurailatpamhemantakumar/anaconda3/envs/gwsnr/lib/python3.10/site-packages/sklearn/base.py:442: InconsistentVersionWarning: Trying to unpickle estimator StandardScaler from version 1.7.0 when using version 1.7.1. This might lead to breaking code or invalid results. Use at your own risk. 
For more info please refer to:\n", "https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations\n", " warnings.warn(\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from datetime import datetime\n", "import gwsnr\n", "\n", "\n", "gwsnr_hybrid = gwsnr.GWSNR(\n", " npool=4,\n", " snr_method='ann',\n", " waveform_approximant='IMRPhenomXPHM', # Use a specific waveform approximant\n", " gwsnr_verbose=False,\n", " snr_recalculation=True, # Enable recalculation of SNRs\n", " snr_recalculation_range=[6,18], # Range of optimal SNR for recalculation\n", " snr_recalculation_waveform_approximant='IMRPhenomXPHM', # Waveform approximant for recalculation. This can be different from the main one\n", ")\n", "\n", "gwsnr_bilby = gwsnr.GWSNR(\n", " npool=4,\n", " snr_method='inner_product', \n", " waveform_approximant='IMRPhenomXPHM', \n", " gwsnr_verbose=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Recalculating SNR for 3776 out of 10000 samples in the SNR range of 6 to 18\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████████████████████████████████████████████████████| 3776/3776 [00:04<00:00, 882.75it/s]\n", "100%|████████████████████████████████████████████████████████| 10000/10000 [00:13<00:00, 718.46it/s]\n" ] } ], "source": [ "# general case, random parameters\n", "# set the random seed for reproducibility\n", "np.random.seed(42)\n", "nsamples = 10000\n", "mtot = np.random.uniform(2*4.98, 2*112.5,nsamples)\n", "mass_ratio = np.random.uniform(0.2,1,size=nsamples)\n", "param_dict = dict(\n", " # convert to component masses\n", " mass_1 = mtot / (1 + mass_ratio),\n", " mass_2 = mtot * mass_ratio / (1 + mass_ratio),\n", " # Randomly sample the luminosity distance\n", " luminosity_distance = np.random.uniform(100, 10000, size=nsamples), # Random luminosity distance between 100 and 10000 Mpc\n", " # 
Randomly sample everything else:\n", " theta_jn = np.random.uniform(0,2*np.pi, size=nsamples),\n", " ra = np.random.uniform(0,2*np.pi, size=nsamples), \n", " dec = np.random.uniform(-np.pi/2,np.pi/2, size=nsamples), \n", " psi = np.random.uniform(0,2*np.pi, size=nsamples),\n", " phase = np.random.uniform(0,2*np.pi, size=nsamples),\n", " geocent_time = 1246527224.169434*np.ones(nsamples),\n", " # random spin magnitudes and orientations (precessing spins)\n", " a_1 = np.random.uniform(0.0,0.8, size=nsamples),\n", " a_2 = np.random.uniform(0.0,0.8, size=nsamples),\n", " tilt_1 = np.random.uniform(0, np.pi, size=nsamples), # tilt angle of the primary black hole in radians\n", " tilt_2 = np.random.uniform(0, np.pi, size=nsamples),\n", " phi_12 = np.random.uniform(0, 2*np.pi, size=nsamples), # Relative angle between the primary and secondary spin of the binary in radians\n", " phi_jl = np.random.uniform(0, 2*np.pi, size=nsamples), # Angle between the total angular momentum and the orbital angular momentum in radians\n", ")\n", "\n", "# ANN model method, with SNR recalculation\n", "hybrid_pdet = gwsnr_hybrid.pdet(gw_param_dict=param_dict.copy(), pdet_type='boolean', distribution_type='fixed_snr')\n", "\n", "# inner product method\n", "bilby_pdet = gwsnr_bilby.pdet(gw_param_dict=param_dict.copy(), pdet_type='boolean', distribution_type='fixed_snr')" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Confusion Matrix:\n", "[[6360 0]\n", " [ 1 3639]]\n", "Accuracy: 99.99 %\n" ] } ], "source": [ "# accuracy of the probability of detection\n", "from sklearn.metrics import confusion_matrix, accuracy_score\n", "\n", "pdet_pred = hybrid_pdet['pdet_net']\n", "pdet_true = bilby_pdet['pdet_net']\n", "\n", "cm = confusion_matrix(pdet_true, pdet_pred)\n", "print(\"Confusion Matrix:\")\n", "print(cm)\n", "\n", "acc = accuracy_score(pdet_true, pdet_pred)\n", "print(f\"Accuracy: {acc*100:.2f} %\")" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "## JAX assisted Inner product with `ripplegw` as backend" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWSNR class...\n", "\n", "psds not given. Choosing bilby's default psds\n", "\n", "\n" ] } ], "source": [ "import numpy as np\n", "import gwsnr\n", "\n", "# innitialize the class for pdet calculation\n", "gwsnr = gwsnr.GWSNR(\n", " snr_method='inner_product_jax', \n", " waveform_approximant='IMRPhenomXAS', \n", " gwsnr_verbose=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/phurailatpamhemantakumar/anaconda3/envs/gwsnr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", " self.pid = os.fork()\n", "100%|█████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 8.52it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "SNR (inner product JAX assisted) : [11.35192063 9.22035654 17.49730142 17.53788394]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n", "100%|█████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 12.69it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "SNR (inner product bilby) : [11.35192174 9.22036535 17.49712218 17.53863404]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# set the GW parameters\n", "mass_1 = np.array([5, 10.,50.,200.])\n", "ratio = np.array([1, 0.8,0.5,0.2])\n", "param_dict = dict(\n", " mass_1 = mass_1,\n", " mass_2 = mass_1*ratio,\n", " luminosity_distance = np.array([1000, 2000, 3000, 4000]),\n", " theta_jn = np.array([0.1, 0.2, 0.3, 0.4]),\n", " ra = np.array([0.1, 0.2, 0.3, 0.4]), \n", " dec = 
np.array([0.1, 0.2, 0.3, 0.4]), \n", " psi = np.array([0.1, 0.2, 0.3, 0.4]),\n", " a_1 = np.array([0.1, 0.2, 0.3, 0.4]),\n", " a_2 = np.array([0.1, 0.2, 0.3, 0.4]),\n", " geocent_time = np.array([0.0, 0.0, 0.0, 0.0]),\n", " phase = np.array([0.0, 0.0, 0.0, 0.0]),\n", ")\n", "\n", "# jax.jit functions are slow on the first call because they are compiled then\n", "snr_jax = gwsnr.optimal_snr(gw_param_dict=param_dict)\n", "print(\"SNR (inner product JAX assisted) : \", snr_jax[\"snr_net\"])\n", "\n", "# snr with inner product\n", "snr_bilby = gwsnr.optimal_snr_with_inner_product(gw_param_dict=param_dict)\n", "print(\"SNR (inner product bilby) : \", snr_bilby[\"snr_net\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* As of 27 Oct 2025, the JAX implementation of the inner product is not yet fully optimized.\n", "\n", "* While waveform generation is fast via `jax.vmap` parallelization, the overall inner product calculation has extra overhead, which will be optimized in the future." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Changing the initialization arguments\n", "\n", "To change the initialization arguments, pass different values to the `GWSNR` class at initialization. Below is an example of how to set up the `GWSNR` class with different parameters:\n", "\n", "What will change in this example? \n", "\n", "* ifos: CE, ET ; interferometers will be changed to Cosmic Explorer (CE) and Einstein Telescope (ET). \n", "\n", "* minimum_frequency: 10 Hz\n", "\n", "* mtot_cut: True ; this will set SNR=0 for total mass > mtot_max, and mtot_max is set according to chirp_time and minimum frequency.\n", "\n", "* Waveform model: TaylorF2\n", "\n", "* multiprocessing_verbose: False ; The progress bar won't be shown but the calculation will be faster.\n", "\n", "* mtot_min: 2*1.0 ; minimum total mass in solar masses Mo. 
1.0 Mo is the minimum component mass of BNS systems in GWTC-3\n", "\n", "* snr_method: 'interpolation_no_spins_jax' ; this will use the interpolation method with no spins and JAX acceleration." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWSNR class...\n", "\n", "Interpolator will be generated for CE detector at ./interpolator_pickle/CE/partialSNR_dict_1.pickle\n", "Interpolator will be generated for ET1 detector at ./interpolator_pickle/ET1/partialSNR_dict_1.pickle\n", "Interpolator will be generated for ET2 detector at ./interpolator_pickle/ET2/partialSNR_dict_1.pickle\n", "Interpolator will be generated for ET3 detector at ./interpolator_pickle/ET3/partialSNR_dict_1.pickle\n", "Please be patient while the interpolator is generated\n", "Generating interpolator for ['CE', 'ET1', 'ET2', 'ET3'] detectors\n", "\n", "Saving Partial-SNR for CE detector with shape (20, 200)\n", "\n", "Saving Partial-SNR for ET1 detector with shape (20, 200)\n", "\n", "Saving Partial-SNR for ET2 detector with shape (20, 200)\n", "\n", "Saving Partial-SNR for ET3 detector with shape (20, 200)\n", "\n", "Chosen GWSNR initialization parameters:\n", "\n", "npool: 8\n", "snr type: interpolation_no_spins_jax\n", "waveform approximant: TaylorF2\n", "sampling frequency: 2048.0\n", "minimum frequency (fmin): 10.0\n", "reference frequency (f_ref): 10.0\n", "mtot=mass1+mass2\n", "min(mtot): 2\n", "max(mtot) (with the given fmin=10.0): 235.0\n", "detectors: ['CE', 'ET1', 'ET2', 'ET3']\n", "psds: [[array([ 5. , 5.01153 , 5.0230867, ...,\n", " 4977.0194 , 4988.4965 , 5000. 
]), array([1.36418639e-44, 1.28941521e-44, 1.21896416e-44, ...,\n", " 3.43993064e-48, 3.46829426e-48, 3.49694151e-48]), ], [array([1.0000000e+00, 1.0030759e+00, 1.0061612e+00, ...,\n", " 9.9387655e+03, 9.9693357e+03, 1.0000000e+04]), array([1.22216783e-33, 1.25637893e-33, 1.28147776e-33, ...,\n", " 2.90395567e-47, 2.92250209e-47, 2.94117631e-47]), ], [array([1.0000000e+00, 1.0030759e+00, 1.0061612e+00, ...,\n", " 9.9387655e+03, 9.9693357e+03, 1.0000000e+04]), array([1.22216783e-33, 1.25637893e-33, 1.28147776e-33, ...,\n", " 2.90395567e-47, 2.92250209e-47, 2.94117631e-47]), ], [array([1.0000000e+00, 1.0030759e+00, 1.0061612e+00, ...,\n", " 9.9387655e+03, 9.9693357e+03, 1.0000000e+04]), array([1.22216783e-33, 1.25637893e-33, 1.28147776e-33, ...,\n", " 2.90395567e-47, 2.92250209e-47, 2.94117631e-47]), ]]\n", "\n", "\n" ] } ], "source": [ "# if snr_method = 'inner_product', interpolator will not be created\n", "import gwsnr\n", "gwsnr = gwsnr.GWSNR(\n", " npool = int(8), \n", " minimum_frequency = 10.,\n", " mtot_min = 2*1, \n", " mtot_cut=True,\n", " waveform_approximant='TaylorF2',\n", " snr_method = 'interpolation_no_spins_jax',\n", " #psds = {'CE':'CE_psd.txt', 'ET':'ET_B_psd.txt'}, # if you want to use your own psd\n", " ifos = ['CE', 'ET'], # this will use bilby's default PSDs for CE and ET\n", " multiprocessing_verbose=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'CE': array([ 22.20392799, 36.01436234, 112.26751709, 131.33247375]),\n", " 'ET1': array([1.57145751, 2.54927254, 7.9575057 , 9.31898499]),\n", " 'ET2': array([ 4.48666096, 7.27841663, 22.71943855, 26.60659027]),\n", " 'ET3': array([ 4.99195719, 8.09812546, 25.27814484, 29.60307503]),\n", " 'snr_net': array([ 23.2493782 , 37.71058273, 117.56902313, 137.54750061])}" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mass_1 = np.array([5, 10.,50.,100.])\n", "ratio = np.array([1, 
0.8,0.5,0.2])\n", "dl = 10000 * np.ones_like(mass_1)\n", "gwsnr.optimal_snr(mass_1=mass_1, mass_2=mass_1*ratio, luminosity_distance=dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "gwsnr2", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.19" } }, "nbformat": 4, "nbformat_minor": 2 }