gwsnr.ann.ann_model_generator
This module contains the ANNModelGenerator class which is used to generate the ANN (Artificial Neural Network) model that can be used to predict the SNR of the GW events.
Module Contents
Classes
ANNModelGenerator class is used to generate the ANN model that can be used to predict the SNR of the GW events. |
- class gwsnr.ann.ann_model_generator.ANNModelGenerator(directory='./gwsnr_data', npool=4, gwsnr_verbose=True, snr_th=8.0, snr_method='interpolation_aligned_spins', waveform_approximant='IMRPhenomXPHM', **kwargs)[source]
ANNModelGenerator class is used to generate the ANN model that can be used to predict the SNR of the GW events.
- Parameters:
- npoolint
Number of processors to use for parallel processing. Default is 4.
- gwsnr_verbosebool
If True, print the progress of the GWSNR calculation. Default is True.
- snr_thfloat
SNR threshold for the error calculation. Default is 8.0.
- waveform_approximantstr
Waveform approximant to be used for the GWSNR calculation and the ANN model. Default is “IMRPhenomXPHM”.
- **kwargsdict
Keyword arguments for the GWSNR class. To see the list of available arguments, >>> from gwsnr import GWSNR >>> help(GWSNR)
Examples
>>> from gwsnr import ANNModelGenerator >>> amg = ANNModelGenerator() >>> amg.ann_model_training(gw_param_dict='gw_param_dict.json') # training the ANN model with pre-generated parameter points
- get_input_data(params)[source]
Function to generate input and output data for the neural network
Parameters: idx: index of the parameter points params: dictionary of parameter points
params.keys() = [‘mass_1’, ‘mass_2’, ‘luminosity_distance’, ‘theta_jn’, ‘psi’, ‘geocent_time’, ‘ra’, ‘dec’, ‘a_1’, ‘a_2’, ‘tilt_1’, ‘tilt_2’, ‘L1’]
Returns: X: input data, [snr_partial_[0], amp0[0], eta, chi_eff, theta_jn] y: output data, [L1]
- ann_model_training(gw_param_dict, randomize=True, test_size=0.1, random_state=42, num_nodes_list=[5, 32, 32, 1], activation_fn_list=['relu', 'relu', 'sigmoid', 'linear'], optimizer='adam', loss='mean_squared_error', metrics=['accuracy'], batch_size=32, epochs=100, error_adjustment_snr_range=[4, 10], ann_file_name='ann_model.h5', scaler_file_name='scaler.pkl', error_adjustment_file_name='error_adjustment.json', ann_path_dict_file_name='ann_path_dict.json')[source]
- save_ann_path_dict(ann_file_name='ann_model.h5', scaler_file_name='scaler.pkl', error_adjustment_file_name='error_adjustment.json', ann_path_dict_file_name='ann_path_dict.json')[source]
- load_model_scaler_error(ann_file_name='ann_model.h5', scaler_file_name='scaler.pkl', error_adjustment_file_name=False)[source]