Source code for gwsnr.threshold.snrthresholdfinder

"""
This module implements the SNRThresholdFinder class to determine the optimal SNR threshold for gravitational wave detection using cross-entropy maximization (following Essick et al. 2023).
"""
import multiprocessing as mpp
from tqdm import tqdm
import h5py
import numpy as np
from .crossentropydifference import cross_entropy_difference
# import crossentropydifference
# cross_entropy_difference = crossentropydifference.cross_entropy_difference
from scipy.interpolate import interp1d
from scipy.optimize import minimize_scalar


class SNRThresholdFinder:
    """
    A class to find the optimal SNR threshold for gravitational wave detection using cross-entropy maximization.

    Parameters
    ----------
    catalog_file : str
        Path to the HDF5 file containing the injection catalog data. The file
        should have a structure like the following (refer to
        https://zenodo.org/records/16740117):

        ```
        injections.hdf
        |-- events
        |   |-- z (parameter to be fitted on)
        |   |-- mass1_source (parameter used to select the data)
        |   |-- gstlal_far (original_detection_statistic)
        |   |-- observed_snr_net (projected_detection_statistic)
        ```
    npool : int, optional
        Number of processes used for multiprocessing. Default is 4.
    selection_range : dict, optional
        Dictionary specifying the selection range with keys:
        'key_name' (str): Name of the catalog key on which the selection range is applied.
        'parameter' (array-like or None): Parameter values; read from the catalog if None.
        'range' (tuple): (min, max) range for selection.
        Default is {'key_name': 'mass1_source', 'parameter': None, 'range': (30, 60)}.
    original_detection_statistic : dict, optional
        Dictionary specifying the original detection statistic with keys:
        'key_name' (str): Name of the catalog key for the original detection statistic.
        'parameter' (array-like or None): Parameter values; read from the catalog if None.
        'threshold' (float): Threshold value for the original detection statistic.
        Default is {'key_name': 'gstlal_far', 'parameter': None, 'threshold': 1}.
    projected_detection_statistic : dict, optional
        Dictionary specifying the projected detection statistic with keys:
        'key_name' (str): Name of the catalog key for the projected detection statistic.
        'parameter' (array-like or None): Parameter values; read from the catalog if None.
        'threshold' (float): Threshold value for the projected detection statistic.
        'threshold_search_bounds' (tuple): Bounds for the threshold search.
        Default is {'key_name': 'observed_snr_net', 'parameter': None, 'threshold': None, 'threshold_search_bounds': (4, 14)}.
    parameters_to_fit : dict, optional
        Dictionary specifying the parameter(s) to fit with keys:
        'key_name' (str or list): Name(s) of the catalog key(s) to fit on, e.g. 'z'.
        'parameter' (array-like or None): Parameter values; read from the catalog if None.
        Default is {'key_name': 'z', 'parameter': None}.
    sample_size : int, optional
        Number of samples to use for KDE estimation. Default is 20000.
    multiprocessing_verbose : bool, optional
        Whether to show a progress bar during multiprocessing. Default is True.

    Examples
    --------
    >>> finder = SNRThresholdFinder(catalog_file='injection_catalog.h5')
    >>> best_thr, del_H, H, H_true, snr_thrs = finder.find_threshold(iteration=10)
    >>> print(f"Best SNR threshold: {best_thr:.2f}")

    Instance Attributes
    -------------------
    SNRThresholdFinder class has the following attributes:

    - npool : int
    - multiprocessing_verbose : bool
    - selection_range : dict
    - original_detection_statistic : dict
    - projected_detection_statistic : dict
    - parameters_to_fit : dict
    - sample_size : int

    Instance Methods
    ----------------
    SNRThresholdFinder class has the following methods:

    - det_data : Load and preprocess catalog data
    - find_threshold : Find the optimal SNR threshold
    - find_best_SNR_threshold : Find the best SNR threshold using spline interpolation and optimization
    """

    def __init__(self,
                 catalog_file=None,
                 npool=4,
                 selection_range=None,
                 original_detection_statistic=None,
                 projected_detection_statistic=None,
                 parameters_to_fit=None,
                 sample_size=20000,
                 multiprocessing_verbose=True,
                 ):
        self.npool = npool
        self.multiprocessing_verbose = multiprocessing_verbose

        if selection_range is None:
            selection_range = dict(
                key_name='mass1_source',
                parameter=None,
                range=(30, 60),
            )
        self.selection_range = selection_range

        if original_detection_statistic is None:
            self.original_detection_statistic = dict(
                key_name='gstlal_far',
                parameter=None,
                threshold=1,  # 1 per year
            )
        else:
            self.original_detection_statistic = original_detection_statistic

        if projected_detection_statistic is None:
            self.projected_detection_statistic = dict(
                key_name='observed_snr_net',
                parameter=None,
                threshold=None,  # to be determined
                threshold_search_bounds=(4, 14),
            )
        else:
            self.projected_detection_statistic = projected_detection_statistic

        if parameters_to_fit is None:
            self.parameters_to_fit = dict(
                key_name='z',
                parameter=None,
            )
        else:
            self.parameters_to_fit = parameters_to_fit

        self.sample_size = sample_size

        self.det_data(catalog_file)
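    # Usage note (illustrative, not part of the original source): when
    # catalog_file is None, each of selection_range, original_detection_statistic,
    # projected_detection_statistic and parameters_to_fit must carry its own
    # 'parameter' array (see det_data below), e.g.
    #
    #   finder = SNRThresholdFinder(
    #       catalog_file=None,
    #       selection_range=dict(key_name='mass1_source', parameter=m1, range=(30, 60)),
    #       original_detection_statistic=dict(key_name='gstlal_far', parameter=far, threshold=1),
    #       projected_detection_statistic=dict(
    #           key_name='observed_snr_net', parameter=snr,
    #           threshold=None, threshold_search_bounds=(4, 14)),
    #       parameters_to_fit=dict(key_name='z', parameter=z),
    #   )
    #
    # where m1, far, snr and z are hypothetical, equal-length numpy arrays.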
    def det_data(self, catalog_file):
        """
        Function to load and preprocess the injection catalog data from an HDF5 file,
        or to validate user-provided parameter arrays when no file is given.

        Parameters
        ----------
        catalog_file : str or None
            Path to the HDF5 file containing the injection catalog data. If None,
            the 'parameter' arrays must already be provided in the input dictionaries.

        Returns
        -------
        None
            The 'parameter' entries of selection_range, original_detection_statistic,
            projected_detection_statistic and parameters_to_fit are populated and
            filtered in place.

        Raises
        ------
        ValueError
            If catalog_file is None and a required parameter array is not provided,
            if a requested key is missing from the catalog, or if no injections fall
            within the specified selection range.
        """

        def raise_not_provided(param):
            raise ValueError(f"If catalog_file is not provided, you must provide {param} as a list or numpy array.")

        if catalog_file is None:
            # no catalog file: validate and convert the user-provided arrays
            param = self.selection_range['parameter']
            if isinstance(param, (list, np.ndarray)):
                self.selection_range['parameter'] = np.array(param)
            else:
                raise_not_provided('selection_range["parameter"]')

            param = self.original_detection_statistic['parameter']
            if isinstance(param, (list, np.ndarray)):
                self.original_detection_statistic['parameter'] = np.array(param)
            else:
                raise_not_provided('original_detection_statistic["parameter"]')

            param = self.projected_detection_statistic['parameter']
            if isinstance(param, (list, np.ndarray)):
                self.projected_detection_statistic['parameter'] = np.array(param)
            else:
                raise_not_provided('projected_detection_statistic["parameter"]')

            param = self.parameters_to_fit['parameter']
            if isinstance(param, (list, np.ndarray)):
                self.parameters_to_fit['parameter'] = np.array(param)
            else:
                raise_not_provided('parameters_to_fit["parameter"]')
        else:
            # read the structured 'events' dataset from the catalog file
            with h5py.File(catalog_file, 'r') as obj:
                attrs = dict(obj.attrs.items())
                events = obj['events'][:]

            key_name = self.selection_range['key_name']
            if key_name in events.dtype.names:
                self.selection_range['parameter'] = events[key_name]
            else:
                print(f"[WARNING] {key_name} not found in the catalog. Using the parameter array of the same name if provided.")

            key_name = self.original_detection_statistic['key_name']
            self.original_detection_statistic['parameter'] = events[key_name]

            key_name = self.projected_detection_statistic['key_name']
            self.projected_detection_statistic['parameter'] = events[key_name]

            key_name = self.parameters_to_fit['key_name']
            if isinstance(key_name, list):
                param_array = []
                for i, kn in enumerate(key_name):
                    if isinstance(self.parameters_to_fit['parameter'], (list, np.ndarray)):
                        param_array.append(self.parameters_to_fit['parameter'][i])
                    elif kn in events.dtype.names:
                        param_array.append(events[kn])
                    else:
                        raise ValueError(f"{kn} not found in the catalog. Please provide the parameter array of the same name.")
                self.parameters_to_fit['parameter'] = np.array(param_array)
            else:
                self.parameters_to_fit['parameter'] = events[key_name]

        # select only events within the selection range
        min_val = self.selection_range['range'][0]
        max_val = self.selection_range['range'][1]
        param = self.selection_range['parameter']
        idx_ = (param >= min_val) & (param <= max_val)
        if np.sum(idx_) == 0:
            raise ValueError("No injections found within the specified selection range.")

        dim = len(self.parameters_to_fit['parameter'].shape)
        if dim < 2:
            self.parameters_to_fit['parameter'] = self.parameters_to_fit['parameter'][idx_]
        else:
            raise NotImplementedError("Selection range filtering for multi-dimensional parameters_to_fit is not implemented yet.")
            # param_array = []
            # for i in range(dim):
            #     param_array.append(self.parameters_to_fit['parameter'][i][idx_])
            # self.parameters_to_fit['parameter'] = np.array(param_array)

        self.original_detection_statistic['parameter'] = self.original_detection_statistic['parameter'][idx_]
        self.projected_detection_statistic['parameter'] = self.projected_detection_statistic['parameter'][idx_]
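    # Data-format sketch (an assumption based on the class docstring and on
    # https://zenodo.org/records/16740117): 'events' is read as a single
    # structured dataset whose named fields supply the catalog columns, e.g.
    #
    #   import h5py
    #   import numpy as np
    #   dtype = np.dtype([('z', float), ('mass1_source', float),
    #                     ('gstlal_far', float), ('observed_snr_net', float)])
    #   events = np.zeros(100, dtype=dtype)  # fill with injection values
    #   with h5py.File('injection_catalog.h5', 'w') as f:  # hypothetical path
    #       f.create_dataset('events', data=events)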
    def find_threshold(self, iteration=10, print_output=True, no_multiprocessing=False):
        """
        Function to find the optimal SNR threshold by maximizing the cross-entropy difference.

        Parameters
        ----------
        iteration : int, optional
            Number of SNR thresholds to test within the search bounds. Default is 10.
        print_output : bool, optional
            Whether to print the best SNR threshold. Default is True.
        no_multiprocessing : bool, optional
            If True, evaluate the thresholds in a plain loop instead of a
            multiprocessing pool. Default is False.

        Returns
        -------
        best_thr : float
            The optimal SNR threshold that maximizes the cross-entropy difference.
        del_H : np.ndarray
            Array of cross-entropy differences for each threshold tested.
        H : np.ndarray
            Array of cross-entropy values for the KDE with cut.
        H_true : np.ndarray
            Array of cross-entropy values for the original KDE.
        snr_thrs : np.ndarray
            Array of SNR thresholds tested.
        """

        snr_thrs = np.linspace(
            self.projected_detection_statistic['threshold_search_bounds'][0],
            self.projected_detection_statistic['threshold_search_bounds'][1],
            iteration,
        )
        iters = np.arange(iteration)
        sample_size = self.sample_size
        parameters_to_fit = self.parameters_to_fit.copy()
        original_detection_statistic = self.original_detection_statistic.copy()
        projected_detection_statistic = self.projected_detection_statistic.copy()

        # set up inputs for multiprocessing
        input_args = [(
            snr_thr,
            sample_size,
            np.array(original_detection_statistic['parameter']),
            np.array(original_detection_statistic['threshold']),
            np.array(projected_detection_statistic['parameter']),
            np.array(parameters_to_fit['parameter']),
            iters[i],
        ) for i, snr_thr in enumerate(snr_thrs)]
        input_args = np.array(input_args, dtype=object)

        # allocate result arrays
        del_H = np.zeros(iteration)
        H = np.zeros(iteration)
        H_true = np.zeros(iteration)

        if no_multiprocessing:
            # plain loop; useful for debugging or when multiprocessing is unavailable
            for args in tqdm(input_args, total=len(input_args), ncols=100):
                del_H_i, H_i, H_true_i, iter_i = cross_entropy_difference(args)
                del_H[iter_i] = del_H_i
                H[iter_i] = H_i
                H_true[iter_i] = H_true_i
        else:
            print("If multiprocessing gets stuck, use no_multiprocessing=True.")
            npool = self.npool
            with mpp.Pool(processes=npool) as pool:
                self._multiprocessing_error()
                if self.multiprocessing_verbose:
                    for result in tqdm(
                        pool.imap_unordered(cross_entropy_difference, input_args),
                        total=len(input_args),
                        ncols=100,
                    ):
                        del_H_i, H_i, H_true_i, iter_i = result
                        del_H[iter_i] = del_H_i
                        H[iter_i] = H_i
                        H_true[iter_i] = H_true_i
                else:
                    # with map, without tqdm
                    for result in pool.map(cross_entropy_difference, input_args):
                        del_H_i, H_i, H_true_i, iter_i = result
                        del_H[iter_i] = del_H_i
                        H[iter_i] = H_i
                        H_true[iter_i] = H_true_i

        best_thr = self.find_best_SNR_threshold(snr_thrs, del_H)
        if print_output:
            print(f"Best SNR threshold: {best_thr:.2f}")

        return best_thr, del_H, H, H_true, snr_thrs
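    # Note on the worker function (inferred from the call sites above):
    # cross_entropy_difference receives a single tuple
    #   (snr_thr, sample_size, original_statistic_values, original_threshold,
    #    projected_statistic_values, fit_parameter_values, iteration_index)
    # and returns (del_H, H, H_true, iteration_index), so each result can be
    # written back to the correct slot even when pool.imap_unordered yields
    # results out of order.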
    def find_best_SNR_threshold(self, thrs, del_H):
        """
        Function to find the best SNR threshold using spline interpolation and optimization.

        Parameters
        ----------
        thrs : np.ndarray
            Array of SNR thresholds tested.
        del_H : np.ndarray
            Array of cross-entropy differences for each threshold tested.

        Returns
        -------
        best_thr : float
            The optimal SNR threshold that maximizes the cross-entropy difference.
        """

        spline = interp1d(thrs, del_H, kind='cubic')
        min_bound = np.min(thrs)
        max_bound = np.max(thrs)
        best_thr = minimize_scalar(lambda x: -spline(x), bounds=(min_bound, max_bound), method='bounded').x

        return best_thr
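    # Illustrative follow-up (hypothetical variable names): the coarse
    # (snr_thrs, del_H) grid returned by find_threshold can be refined directly:
    #   best = finder.find_best_SNR_threshold(snr_thrs, del_H)
    # Note that the cubic spline requires at least four grid points,
    # i.e. find_threshold(iteration=4) or more.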
    def _multiprocessing_error(self):
        """
        Raise an error if the worker pool is being created outside the main process,
        i.e. when the calling script is not wrapped in an 'if __name__ == "__main__":' guard.
        """

        # to access multiple cores instead of multithreading
        if mpp.current_process().name != 'MainProcess':
            print(
                "\n\n[ERROR] This multiprocessing code must be run under 'if __name__ == \"__main__\":'.\n"
                "Please wrap your script entry point in this guard.\n"
                "See: https://docs.python.org/3/library/multiprocessing.html#multiprocessing-programming\n"
            )
            raise RuntimeError(
                "\nMultiprocessing code must be run under 'if __name__ == \"__main__\":'.\n\n"
            )
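
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): how the class
# above is expected to be driven from a script. 'injection_catalog.h5' is a
# hypothetical file path; the '__main__' guard is required because
# find_threshold spawns worker processes (see _multiprocessing_error above).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    finder = SNRThresholdFinder(
        catalog_file='injection_catalog.h5',  # hypothetical catalog path
        npool=4,
        selection_range=dict(key_name='mass1_source', parameter=None, range=(30, 60)),
    )
    # coarse scan over the SNR-threshold search bounds, then spline refinement
    best_thr, del_H, H, H_true, snr_thrs = finder.find_threshold(iteration=10)
    print(f"Best SNR threshold: {best_thr:.2f}")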