gwsnr.ann
Package Contents
Classes
ANNModelGenerator: generates an ANN model that predicts the SNR of GW events.
- class gwsnr.ann.ANNModelGenerator(directory='./gwsnr_data', npool=4, gwsnr_verbose=True, snr_th=8.0, snr_method='interpolation_aligned_spins', waveform_approximant='IMRPhenomXPHM', **kwargs)
Generates an artificial neural network (ANN) model that predicts the signal-to-noise ratio (SNR) of gravitational-wave (GW) events.
- Parameters:
- npool : int
Number of processors to use for parallel processing. Default is 4.
- gwsnr_verbose : bool
If True, print the progress of the GWSNR calculation. Default is True.
- snr_th : float
SNR threshold used in the error calculation. Default is 8.0.
- waveform_approximant : str
Waveform approximant used for both the GWSNR calculation and the ANN model. Default is "IMRPhenomXPHM".
- **kwargs : dict
Keyword arguments for the GWSNR class. To list the available arguments:
>>> from gwsnr import GWSNR
>>> help(GWSNR)
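The snr_th parameter is described only as the threshold for the error calculation. One plausible reading, assumed here and not taken from gwsnr, is to score the ANN by the fraction of events whose detectable / not-detectable label agrees between the reference SNR and the predicted SNR. The helper detection_agreement is hypothetical, and the strict `>` comparison is an assumed convention.

```python
import numpy as np


def detection_agreement(snr_true, snr_pred, snr_th=8.0):
    """Return the fraction of events whose detect / no-detect label agrees.

    Hypothetical helper: an event counts as "detected" when its SNR is
    strictly greater than snr_th (an assumed convention).
    """
    snr_true = np.asarray(snr_true, dtype=float)
    snr_pred = np.asarray(snr_pred, dtype=float)
    if snr_true.shape != snr_pred.shape:
        raise ValueError("snr_true and snr_pred must have the same shape")
    detected_true = snr_true > snr_th  # label from the reference SNR
    detected_pred = snr_pred > snr_th  # label from the ANN-predicted SNR
    return float(np.mean(detected_true == detected_pred))


# Reference labels: F, T, T, F; predicted labels: F, T, F, T -> 2 of 4 agree
print(detection_agreement([5.0, 9.0, 12.0, 7.9], [6.0, 8.5, 7.0, 8.1]))  # 0.5
```

A metric like this rewards getting the threshold crossing right rather than the exact SNR value, which is why SNR errors far from snr_th matter less under it.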
Examples
>>> from gwsnr import ANNModelGenerator
>>> amg = ANNModelGenerator()
>>> amg.ann_model_training(gw_param_dict='gw_param_dict.json')  # train the ANN model on pre-generated parameter points
- directory = './gwsnr_data'
- ann_model
- ann = None
- scaler = None
- gwsnr_args
- gwsnr
- get_input_data(params)
Generates the input and output data for the neural network.
Parameters:
- params : dict
Dictionary of parameter points, with params.keys() = ['mass_1', 'mass_2', 'luminosity_distance', 'theta_jn', 'psi', 'geocent_time', 'ra', 'dec', 'a_1', 'a_2', 'tilt_1', 'tilt_2', 'L1']
Returns:
- X : input data, [snr_partial_[0], amp0[0], eta, chi_eff, theta_jn]
- y : output data, [L1]
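get_input_data reduces each parameter point to five input features. Two of them follow from standard formulas: the symmetric mass ratio eta = m1 m2 / (m1 + m2)^2 and the effective spin chi_eff = (m1 a1 cos(tilt_1) + m2 a2 cos(tilt_2)) / (m1 + m2). The sketch below assembles X and y in the documented column order. It is not the library's implementation: build_features is a hypothetical name, and snr_partial and amp0 are passed in as precomputed arrays because this page does not say how they are obtained.

```python
import numpy as np


def build_features(params, snr_partial, amp0):
    """Return (X, y) with X columns [snr_partial, amp0, eta, chi_eff, theta_jn].

    snr_partial and amp0 are hypothetical precomputed per-event arrays;
    this sketch does not compute them.
    """
    m1 = np.asarray(params["mass_1"], dtype=float)
    m2 = np.asarray(params["mass_2"], dtype=float)
    a1 = np.asarray(params["a_1"], dtype=float)
    a2 = np.asarray(params["a_2"], dtype=float)
    tilt1 = np.asarray(params["tilt_1"], dtype=float)
    tilt2 = np.asarray(params["tilt_2"], dtype=float)

    mtot = m1 + m2
    eta = m1 * m2 / mtot**2  # symmetric mass ratio, 0 < eta <= 0.25
    # Effective spin: mass-weighted spin projections onto the orbital angular momentum
    chi_eff = (m1 * a1 * np.cos(tilt1) + m2 * a2 * np.cos(tilt2)) / mtot

    X = np.column_stack([
        np.asarray(snr_partial, dtype=float),
        np.asarray(amp0, dtype=float),
        eta,
        chi_eff,
        np.asarray(params["theta_jn"], dtype=float),
    ])
    y = np.asarray(params["L1"], dtype=float)  # target: the 'L1' SNR values
    return X, y


params = {
    "mass_1": [30.0, 20.0], "mass_2": [30.0, 10.0],
    "a_1": [0.5, 0.0], "a_2": [0.5, 0.0],
    "tilt_1": [0.0, 0.0], "tilt_2": [0.0, 0.0],
    "theta_jn": [0.1, 0.2], "L1": [10.0, 5.0],
}
X, y = build_features(params, snr_partial=[1.0, 2.0], amp0=[3.0, 4.0])
print(X.shape)  # (2, 5)
```

For the equal-mass, aligned-spin first event, eta is 0.25 and chi_eff equals the common spin magnitude 0.5.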
- ann_model_training(gw_param_dict, randomize=True, test_size=0.1, random_state=42, num_nodes_list=[5, 32, 32, 1], activation_fn_list=['relu', 'relu', 'sigmoid', 'linear'], optimizer='adam', loss='mean_squared_error', metrics=['accuracy'], batch_size=32, epochs=100, error_adjustment_snr_range=[4, 10], ann_file_name='ann_model.h5', scaler_file_name='scaler.pkl', error_adjustment_file_name='error_adjustment.json', ann_path_dict_file_name='ann_path_dict.json')
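ann_model_training receives the network architecture as two parallel lists. Assuming each entry of num_nodes_list is the width of one dense layer and activation_fn_list[i] is the activation of layer i, the defaults describe layers of 5, 32, 32 and 1 units with relu, relu, sigmoid and linear activations; the first entry, 5, equals the number of input features returned by get_input_data, and the final linear unit emits one SNR value. The NumPy sketch below only traces that forward pass with random weights. It does not reproduce the library's training (optimizer, loss, batch_size, epochs), and init_dense_stack and forward are hypothetical names.

```python
import numpy as np

# Activation functions keyed by the names used in activation_fn_list
ACTIVATIONS = {
    "relu": lambda z: np.maximum(z, 0.0),
    "sigmoid": lambda z: 1.0 / (1.0 + np.exp(-z)),
    "linear": lambda z: z,
}


def init_dense_stack(n_inputs, num_nodes_list, seed=42):
    """Create (weights, bias) pairs: one dense layer per entry of num_nodes_list."""
    rng = np.random.default_rng(seed)
    layers = []
    fan_in = n_inputs
    for n_out in num_nodes_list:
        # He-style scaling keeps relu activations in a sensible range
        w = rng.normal(0.0, np.sqrt(2.0 / fan_in), size=(fan_in, n_out))
        layers.append((w, np.zeros(n_out)))
        fan_in = n_out
    return layers


def forward(x, layers, activation_fn_list):
    """Apply each dense layer, then the activation paired with it by position."""
    if len(layers) != len(activation_fn_list):
        raise ValueError("need exactly one activation per layer")
    h = np.asarray(x, dtype=float)
    for (w, b), name in zip(layers, activation_fn_list):
        h = ACTIVATIONS[name](h @ w + b)
    return h


layers = init_dense_stack(n_inputs=5, num_nodes_list=[5, 32, 32, 1])
X_batch = np.ones((3, 5))  # 3 events, 5 input features each
snr_pred = forward(X_batch, layers, ["relu", "relu", "sigmoid", "linear"])
print(snr_pred.shape)  # (3, 1)
```

Keeping the final activation linear lets the output take any real value, which suits a regression target such as SNR.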
- save_ann_path_dict(ann_file_name='ann_model.h5', scaler_file_name='scaler.pkl', error_adjustment_file_name='error_adjustment.json', ann_path_dict_file_name='ann_path_dict.json')
- load_model_scaler_error(ann_file_name='ann_model.h5', scaler_file_name='scaler.pkl', error_adjustment_file_name=False)
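save_ann_path_dict and load_model_scaler_error both work with the file names of the trained model, its scaler and the error-adjustment file, and ann_path_dict_file_name defaults to a JSON file. A minimal sketch of such a JSON path registry follows, assuming one entry per detector (the documented training target is the L1 SNR). The functions save_path_dict and load_path_dict and the keys model_path, scaler_path and error_adjustment_path are hypothetical and do not describe gwsnr's actual file format; the sketch also does not load the model or scaler themselves.

```python
import json
import os
import tempfile


def save_path_dict(directory, detector, ann_file_name="ann_model.h5",
                   scaler_file_name="scaler.pkl",
                   error_adjustment_file_name="error_adjustment.json",
                   ann_path_dict_file_name="ann_path_dict.json"):
    """Add or update one detector's file paths in a JSON path registry.

    Hypothetical helper; the key names below are illustrative only.
    """
    path = os.path.join(directory, ann_path_dict_file_name)
    path_dict = {}
    if os.path.exists(path):  # keep entries already written for other detectors
        with open(path) as f:
            path_dict = json.load(f)
    path_dict[detector] = {
        "model_path": os.path.join(directory, ann_file_name),
        "scaler_path": os.path.join(directory, scaler_file_name),
        "error_adjustment_path": os.path.join(directory, error_adjustment_file_name),
    }
    with open(path, "w") as f:
        json.dump(path_dict, f, indent=2)
    return path


def load_path_dict(path):
    """Read the registry back as a plain dict."""
    with open(path) as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    registry = save_path_dict(tmp, "L1")
    save_path_dict(tmp, "H1", ann_file_name="ann_model_H1.h5")
    print(sorted(load_path_dict(registry)))  # ['H1', 'L1']
```

Storing paths rather than the objects keeps the registry small and human-readable, while the model, scaler and error-adjustment files stay in their own formats.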