gwsnr.ann.ann_model_generator
ANN Model Generator for Gravitational Wave SNR Prediction.
This module provides the ANNModelGenerator class for training artificial neural network (ANN) models that predict gravitational wave signal-to-noise ratios (SNR). It supports fast training of custom ANN models for various detector configurations and waveform approximants, integrates with the GWSNR interpolation framework for feature extraction, and provides tools for model evaluation and error correction.
Key Features:
- ANN model training for single-detector SNR prediction
- Feature extraction using interpolated partial-scaled SNR values
- StandardScaler normalization for input features
- Linear error adjustment for improved prediction accuracy
- Confusion matrix and accuracy evaluation for detection classification
- Model, scaler, and configuration persistence
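As a minimal sketch of the StandardScaler step, the pure-Python code below standardizes each input column to zero mean and unit (population) variance, which is the transform sklearn's StandardScaler applies. The feature rows are made-up values in the documented feature order, and the helper names are hypothetical, not part of gwsnr.

```python
import math


# Hypothetical helper (not part of gwsnr): per-column mean and population std.
def fit_standard_scaler(rows):
    n_cols = len(rows[0])
    means = [sum(r[j] for r in rows) / len(rows) for j in range(n_cols)]
    stds = []
    for j in range(n_cols):
        var = sum((r[j] - means[j]) ** 2 for r in rows) / len(rows)
        # Like StandardScaler, leave zero-variance columns unscaled (scale = 1)
        stds.append(math.sqrt(var) if var > 0 else 1.0)
    return means, stds


# Hypothetical helper: z = (x - mean) / std, column by column.
def transform(rows, means, stds):
    return [[(x - m) / s for x, m, s in zip(r, means, stds)] for r in rows]


# Illustrative rows: [partial_scaled_snr, amplitude_factor, eta, chi_eff, theta_jn]
X_train = [
    [12.0, 0.5, 0.25, 0.1, 0.3],
    [8.0, 0.2, 0.20, -0.2, 1.2],
    [20.0, 0.9, 0.24, 0.4, 2.0],
]
means, stds = fit_standard_scaler(X_train)
X_scaled = transform(X_train, means, stds)
```

The same fitted means and stds must be reused for test and prediction inputs, which is why the scaler is persisted alongside the model.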
Copyright (C) 2025 Hemantakumar Phurailatpam and Otto Hannuksela. Distributed under MIT License.
Module Contents
Classes
ANNModelGenerator | Generate and train ANN models for gravitational wave SNR prediction.
- class gwsnr.ann.ann_model_generator.ANNModelGenerator(directory='./gwsnr_data', npool=4, gwsnr_verbose=True, snr_th=8.0, snr_method='interpolation_aligned_spins', waveform_approximant='IMRPhenomXPHM', **kwargs)
Generate and train ANN models for gravitational wave SNR prediction.
Provides functionality to train artificial neural network models that predict optimal SNR for gravitational wave signals from compact binary coalescences. Uses interpolated partial-scaled SNR values as input features along with intrinsic binary parameters.
Key Features:
- TensorFlow/Keras-based ANN model training
- Feature extraction using GWSNR interpolation framework
- StandardScaler normalization for input features
- Linear error adjustment for improved prediction accuracy
- Model evaluation with confusion matrix and accuracy metrics
- Parameters:
- directory
str Output directory for saving trained models, scalers, and configurations.
default: ‘./gwsnr_data’
- npool
int Number of processors for parallel GWSNR computation.
default: 4
- gwsnr_verbose
bool If True, print GWSNR initialization progress.
default: True
- snr_th
float SNR threshold for detection classification.
default: 8.0
- snr_method
str SNR calculation method for GWSNR initialization.
default: ‘interpolation_aligned_spins’
- waveform_approximant
str Waveform approximant for SNR calculation and ANN training.
default: ‘IMRPhenomXPHM’
- **kwargs
dict Additional keyword arguments passed to GWSNR. See the GWSNR documentation for available options.
Notes
- ANN input features: [partial_scaled_snr, amplitude_factor, eta, chi_eff, theta_jn]
- Training requires pre-generated GW parameter samples with computed SNR values
- Error adjustment improves predictions via a linear correction: y_adj = y_pred - (slope*y_pred + intercept), where slope and intercept are stored in error_adjustment
- Only single-detector training is supported per instance
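The linear correction from the Notes can be written out directly. This sketch assumes the fitted parameters are a dict with 'slope' and 'intercept' keys, as described for error_adjustment; the numbers are illustrative and apply_error_adjustment is a hypothetical helper.

```python
# Hypothetical helper (not part of gwsnr) applying the documented correction:
# y_adj = y_pred - (slope*y_pred + intercept)
def apply_error_adjustment(y_pred, slope, intercept):
    return [y - (slope * y + intercept) for y in y_pred]


# Illustrative parameters; real values come from the fitted error_adjustment dict
error_adjustment = {"slope": 0.05, "intercept": -0.2}
y_pred = [4.0, 8.0, 12.0]
y_adj = apply_error_adjustment(
    y_pred, error_adjustment["slope"], error_adjustment["intercept"]
)
```

Algebraically this is y_adj = (1 - slope)*y_pred - intercept, so the correction rescales and shifts every prediction by the same linear map.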
Examples
Basic ANN model training:
>>> from gwsnr import ANNModelGenerator
>>> amg = ANNModelGenerator()
>>> amg.ann_model_training(gw_param_dict='gw_param_dict.json')
Custom configuration:
>>> amg = ANNModelGenerator(
...     directory='./custom_output',
...     snr_th=10.0,
...     waveform_approximant='IMRPhenomD')
>>> amg.ann_model_training(
...     gw_param_dict=params,
...     epochs=200,
...     batch_size=64)
Instance Methods
ANNModelGenerator class has the following methods:
Method | Description
ann_model_training() | Train ANN model with parameter data
load_model_scaler_error() | Load pre-trained model, scaler, and error data
predict_snr() | Predict SNR using trained ANN model
predict_pdet() | Predict detection probability using trained ANN model
pdet_error() | Calculate detection probability error rate
pdet_confusion_matrix() | Generate confusion matrix for Pdet evaluation
snr_error_adjustment() | Update and save error adjustment parameters
Instance Attributes
ANNModelGenerator class has the following attributes:
Attribute | Type | Unit | Description
directory | str | | Output directory for model files
ann_model | function | | ANN model constructor function
ann | Model or None | | Trained Keras model instance
scaler | StandardScaler or None | | StandardScaler for feature normalization
gwsnr_args | dict | | GWSNR initialization arguments
gwsnr | GWSNR | | GWSNR instance for interpolation
X_test | ndarray | | Scaled test input features
y_test | ndarray | | Test output labels (SNR values)
error_adjustment | dict | | Error correction parameters (slope, intercept)
- property directory
Output directory for model files.
- Returns:
- directory
str Output directory path for saving trained models, scalers, and configurations.
default: ‘./gwsnr_data’
- property ann_model
ANN model constructor function.
- Returns:
- ann_model
function Function that creates and compiles a Keras Sequential model.
- property ann
Trained Keras model instance.
- Returns:
- ann
tensorflow.keras.Model or None Trained ANN model, or None if not yet trained/loaded.
- property scaler
StandardScaler for feature normalization.
- Returns:
- scaler
sklearn.preprocessing.StandardScaler or None Fitted scaler for input feature normalization, or None if not fitted.
- property gwsnr_args
GWSNR initialization arguments.
- Returns:
- gwsnr_args
dict Dictionary of GWSNR configuration parameters.
- property gwsnr
GWSNR instance for interpolation.
- Returns:
- gwsnr
GWSNR GWSNR instance used for partial-scaled SNR interpolation.
- property X_test
Scaled test input features.
- Returns:
- X_test
numpy.ndarray or None Scaled test input feature array, or None if not set.
- property y_test
Test output labels (SNR values).
- Returns:
- y_test
numpy.ndarray or None Test output SNR values array, or None if not set.
- property error_adjustment
Error correction parameters.
- Returns:
- error_adjustment
dict or None Dictionary with ‘slope’ and ‘intercept’ keys for linear error correction, or None if not computed.
- ann_model_training(gw_param_dict, randomize=True, test_size=0.1, random_state=42, num_nodes_list=[5, 32, 32, 1], activation_fn_list=['relu', 'relu', 'sigmoid', 'linear'], optimizer='adam', loss='mean_squared_error', metrics=['accuracy'], batch_size=32, epochs=100, error_adjustment_snr_range=[4, 10], ann_file_name='ann_model.h5', scaler_file_name='scaler.pkl', error_adjustment_file_name='error_adjustment.json', ann_path_dict_file_name='ann_path_dict.json')
Train ANN model for SNR prediction using GW parameter data.
Complete training pipeline including data preparation, model training, error adjustment calculation, and file saving.
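Data preparation is controlled by randomize, test_size, and random_state. A rough sketch of a seeded hold-out split over sample indices, assuming that is how the split works: it follows sklearn's train_test_split convention of rounding the test count up, but uses Python's random module, so the exact indices differ from what the library produces. shuffle_split is a hypothetical name.

```python
import math
import random


# Hypothetical helper (not part of gwsnr): returns (train_idx, test_idx).
def shuffle_split(n_samples, test_size=0.1, random_state=42, randomize=True):
    idx = list(range(n_samples))
    if randomize:
        # Seeded shuffle so repeated calls give the same split
        random.Random(random_state).shuffle(idx)
    # Round the test count up, as train_test_split does for float test_size
    n_test = math.ceil(test_size * n_samples)
    return idx[n_test:], idx[:n_test]


train_idx, test_idx = shuffle_split(100, test_size=0.1, random_state=42)
```

Fixing random_state makes the held-out set reproducible, so X_test and y_test can be regenerated and compared across runs.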
- Parameters:
- gw_param_dict
dict or str GW parameter dictionary or path to JSON file containing training data.
- randomize
bool If True, randomly shuffle the training data.
default: True
- test_size
float Fraction of data to hold out for testing.
default: 0.1
- random_state
int Random state for train/test split reproducibility.
default: 42
- num_nodes_list
list of int Number of nodes in each layer.
default: [5, 32, 32, 1]
- activation_fn_list
list of str Activation functions for each layer.
default: [‘relu’, ‘relu’, ‘sigmoid’, ‘linear’]
- optimizer
str Keras optimizer name.
default: ‘adam’
- loss
str Keras loss function name.
default: ‘mean_squared_error’
- metrics
list of str Metrics to evaluate during training.
default: [‘accuracy’]
- batch_size
int Batch size for training.
default: 32
- epochs
int Number of training epochs.
default: 100
- error_adjustment_snr_range
list of float SNR range [min, max] for computing error adjustment parameters.
default: [4, 10]
- ann_file_name
str Output filename for trained model.
default: ‘ann_model.h5’
- scaler_file_name
str Output filename for fitted scaler.
default: ‘scaler.pkl’
- error_adjustment_file_name
str Output filename for error adjustment parameters.
default: ‘error_adjustment.json’
- ann_path_dict_file_name
str Output filename for ANN configuration paths.
default: ‘ann_path_dict.json’
Notes
- Saves model to directory/{ann_file_name}
- Saves scaler to directory/{scaler_file_name}
- Computes error adjustment using _helper_error_adjustment()
Examples
>>> amg = ANNModelGenerator()
>>> amg.ann_model_training(
...     gw_param_dict='training_params.json',
...     epochs=200,
...     batch_size=64)
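One plausible reading of num_nodes_list and activation_fn_list is a stack of fully connected layers, where entry i gives the width of layer i and the activation applied to it. The pure-Python forward pass below illustrates that reading with random, untrained weights and the five documented input features. It is an assumption about how the lists map to layers, not the module's Keras code, and both helpers are hypothetical.

```python
import math
import random

# Activations named in the default activation_fn_list
ACTIVATIONS = {
    "relu": lambda x: max(0.0, x),
    "sigmoid": lambda x: 1.0 / (1.0 + math.exp(-x)),
    "linear": lambda x: x,
}


# Hypothetical helper: random weights/biases for stacked dense layers.
def build_layers(n_inputs, num_nodes_list, seed=0):
    rng = random.Random(seed)
    layers, fan_in = [], n_inputs
    for n_out in num_nodes_list:
        weights = [[rng.gauss(0.0, 0.1) for _ in range(fan_in)] for _ in range(n_out)]
        biases = [0.0] * n_out
        layers.append((weights, biases))
        fan_in = n_out
    return layers


# Hypothetical helper: each layer computes act(W @ x + b).
def forward(x, layers, activation_fn_list):
    for (weights, biases), name in zip(layers, activation_fn_list):
        act = ACTIVATIONS[name]
        x = [act(sum(w * xi for w, xi in zip(row, x)) + b)
             for row, b in zip(weights, biases)]
    return x


num_nodes_list = [5, 32, 32, 1]
activation_fn_list = ["relu", "relu", "sigmoid", "linear"]
layers = build_layers(5, num_nodes_list)
# One scaled feature vector in, one SNR-like value out
snr_out = forward([0.1, -0.3, 0.5, 0.0, 1.2], layers, activation_fn_list)
```

Under this reading, the final 'linear' layer of width 1 is what lets the network output an unbounded SNR value rather than a probability.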
- pdet_error(gw_param_dict=None, randomize=True, error_adjustment=True)
Calculate detection probability error rate.
Evaluates the percentage of samples where predicted and true detection status (SNR > threshold) differ.
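A minimal sketch of that error metric, assuming SNR values arrive as plain lists and "detected" means strictly exceeding the threshold (the rule stated for predict_pdet). pdet_error_rate is a hypothetical helper and the SNR values are made up.

```python
# Hypothetical helper (not part of gwsnr): percentage of samples whose
# detected/not-detected status differs between prediction and truth.
def pdet_error_rate(y_pred, y_true, snr_th=8.0):
    mismatches = sum((p > snr_th) != (t > snr_th) for p, t in zip(y_pred, y_true))
    return 100.0 * mismatches / len(y_true)


y_true = [5.0, 7.9, 8.5, 12.0]
y_pred = [5.2, 8.1, 8.3, 11.0]
error = pdet_error_rate(y_pred, y_true)
```

Only the second sample flips across the threshold (7.9 true, 8.1 predicted), so one of four samples is misclassified.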
- Parameters:
- gw_param_dict
dict or str GW parameter dictionary or JSON file path. If None, uses stored test data. Optional.
- randomize
bool If True, randomly shuffle parameters (only used if gw_param_dict provided).
default: True
- error_adjustment
bool If True, apply linear error correction to predictions.
default: True
- Returns:
- error
float Percentage of misclassified samples.
- y_pred
numpy.ndarray Predicted SNR values (with or without error adjustment).
- y_test
numpy.ndarray True SNR values.
Notes
- Uses gwsnr_args['snr_th'] as the detection threshold
- Error adjustment: y_adj = y_pred - (slope*y_pred + intercept)
- pdet_confusion_matrix(gw_param_dict=None, randomize=True, snr_threshold=8.0)
Generate confusion matrix for detection probability classification.
Evaluates ANN predictions as binary classification (detected/not detected) and computes confusion matrix and accuracy metrics.
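To show how the 2x2 matrix is laid out, the sketch below builds it by hand with the same convention as sklearn.metrics.confusion_matrix for boolean labels: rows are true status, columns are predicted status, and False comes before True, giving [[TN, FP], [FN, TP]]. The helper name and data are hypothetical.

```python
# Hypothetical helper (not part of gwsnr): 2x2 confusion matrix and accuracy (%).
def confusion_and_accuracy(y_pred, y_true, snr_threshold=8.0):
    pred = [p > snr_threshold for p in y_pred]
    true = [t > snr_threshold for t in y_true]
    cm = [[0, 0], [0, 0]]
    for t, p in zip(true, pred):
        # Row index = true status, column index = predicted status
        cm[int(t)][int(p)] += 1
    accuracy = 100.0 * (cm[0][0] + cm[1][1]) / len(true)
    return cm, accuracy


y_true = [5.0, 7.9, 8.5, 12.0]
y_pred = [5.2, 8.1, 8.3, 11.0]
cm, accuracy = confusion_and_accuracy(y_pred, y_true)
```

Here the off-diagonal entry cm[0][1] counts the one false positive (true SNR 7.9, predicted 8.1).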
- Parameters:
- gw_param_dict
dict or str GW parameter dictionary or JSON file path. If None, uses stored test data. Optional.
- randomize
bool If True, randomly shuffle parameters (only used if gw_param_dict provided).
default: True
- snr_threshold
float SNR threshold for detection classification.
default: 8.0
- Returns:
- cm
numpy.ndarray Confusion matrix of shape (2, 2).
- accuracy
float Classification accuracy percentage.
- y_pred
numpy.ndarray Predicted detection status (boolean array).
- y_test
numpy.ndarray True detection status (boolean array).
Notes
- Uses sklearn.metrics.confusion_matrix and accuracy_score
- Prints confusion matrix and accuracy to stdout
- load_model_scaler_error(ann_file_name='ann_model.h5', scaler_file_name='scaler.pkl', error_adjustment_file_name=False)
Load pre-trained ANN model, scaler, and optionally error adjustment.
Restores saved model components for prediction use.
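Assuming the error adjustment file is a flat JSON object holding 'slope' and 'intercept' (the keys documented for error_adjustment), a save/load round trip through directory looks roughly like this. Both helpers are hypothetical; the file actually written by ann_model_training may contain more fields.

```python
import json
import os
import tempfile


# Hypothetical helper (not part of gwsnr): write the dict as JSON in `directory`.
def save_error_adjustment(params, directory, file_name="error_adjustment.json"):
    path = os.path.join(directory, file_name)
    with open(path, "w") as f:
        json.dump(params, f)
    return path


# Hypothetical helper: read the dict back from `directory`.
def load_error_adjustment(directory, file_name="error_adjustment.json"):
    with open(os.path.join(directory, file_name)) as f:
        return json.load(f)


# A temporary directory stands in for ./gwsnr_data
with tempfile.TemporaryDirectory() as tmp:
    save_error_adjustment({"slope": 0.05, "intercept": -0.2}, tmp)
    loaded = load_error_adjustment(tmp)
```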
- Parameters:
- ann_file_name
str Filename of the trained model.
default: ‘ann_model.h5’
- scaler_file_name
str Filename of the fitted scaler.
default: ‘scaler.pkl’
- error_adjustment_file_name
str or bool Filename of error adjustment parameters. If False, not loaded.
default: False
- Returns:
- ann
tensorflow.keras.Model Loaded Keras model.
- scaler
sklearn.preprocessing.StandardScaler Loaded scaler.
- error_adjustment
dict Error adjustment parameters (only returned if error_adjustment_file_name is provided). Optional.
Notes
- Updates ann, scaler, and optionally error_adjustment
- Files are loaded from directory
- snr_error_adjustment(gw_param_dict=None, randomize=True, snr_range=[4, 10], error_adjustment_file_name='error_adjustment.json')
Recalculate and save error adjustment parameters.
Computes new error adjustment based on current predictions and updates the stored parameters.
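The fitting procedure lives in the module's helpers and is not documented here. A minimal sketch, assuming an ordinary least-squares fit of the residual y_pred - y_true against y_pred, restricted to samples whose true SNR lies inside snr_range; both the fit method and the selection rule are assumptions. Under that model, the documented correction y_adj = y_pred - (slope*y_pred + intercept) subtracts the predicted residual.

```python
# Hypothetical helper (not part of gwsnr): least-squares fit of
# (y_pred - y_true) = slope*y_pred + intercept within snr_range.
def fit_error_adjustment(y_pred, y_true, snr_range=(4, 10)):
    lo, hi = snr_range
    # Assumption: select samples by their true SNR
    pairs = [(p, t) for p, t in zip(y_pred, y_true) if lo <= t <= hi]
    n = len(pairs)
    mean_x = sum(p for p, _ in pairs) / n
    mean_r = sum(p - t for p, t in pairs) / n
    sxx = sum((p - mean_x) ** 2 for p, _ in pairs)
    sxr = sum((p - mean_x) * ((p - t) - mean_r) for p, t in pairs)
    slope = sxr / sxx
    intercept = mean_r - slope * mean_x
    return {"slope": slope, "intercept": intercept}


# Synthetic predictions with an exactly linear bias (slope 0.1, intercept 0.5)
y_pred = [6.0, 7.0, 8.0, 9.0, 10.0, 11.0]
y_true = [p - (0.1 * p + 0.5) for p in y_pred]
# One sample outside snr_range (by true SNR), ignored by the fit
y_pred.append(30.0)
y_true.append(15.0)
params = fit_error_adjustment(y_pred, y_true, snr_range=(4, 10))
```

Restricting the fit to a band around the detection threshold focuses the correction where misclassification matters most, at the cost of ignoring bias at high SNR.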
- Parameters:
- gw_param_dict
dict or str GW parameter dictionary or JSON file path for evaluation. Optional.
- randomize
bool If True, randomly shuffle parameters.
default: True
- snr_range
list of float SNR range for error adjustment fitting.
default: [4, 10]
- error_adjustment_file_name
str Output filename for updated error adjustment.
default: ‘error_adjustment.json’
- Returns:
- error_adjustment
dict Updated error adjustment parameters with ‘slope’ and ‘intercept’.
Notes
- Calls pdet_error() with error_adjustment=True for predictions
- Saves updated parameters to directory/{error_adjustment_file_name}
- predict_snr(gw_param_dict, error_adjustment=True)
Predict SNR values using trained ANN model.
Applies the trained model to new GW parameters and optionally applies error correction.
- Parameters:
- gw_param_dict
dict or str GW parameter dictionary or path to JSON file.
- error_adjustment
bool If True, apply linear error correction to predictions.
default: True
- Returns:
- y_pred
numpy.ndarray Predicted SNR values (corrected if error_adjustment=True).
- predict_pdet(gw_param_dict, snr_threshold=8.0, error_adjustment=True)
Predict detection probability using trained ANN model.
Classifies events as detected (SNR > threshold) or not detected based on ANN predictions.
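Once SNR predictions exist, predict_pdet reduces to thresholding the (optionally corrected) values. A small sketch, assuming predictions are a plain list and the error adjustment is a dict with 'slope' and 'intercept'; predict_pdet_from_snr is a hypothetical stand-in that skips the ANN call entirely.

```python
# Hypothetical helper (not part of gwsnr): threshold SNR predictions,
# optionally applying y_adj = y_pred - (slope*y_pred + intercept) first.
def predict_pdet_from_snr(y_pred, snr_threshold=8.0, error_adjustment=None):
    if error_adjustment is not None:
        a, b = error_adjustment["slope"], error_adjustment["intercept"]
        y_pred = [y - (a * y + b) for y in y_pred]
    # Strict inequality: an SNR exactly at the threshold is not detected
    return [y > snr_threshold for y in y_pred]


flags = predict_pdet_from_snr([7.9, 8.2, 8.5])
flags_adj = predict_pdet_from_snr(
    [7.9, 8.2, 8.5], error_adjustment={"slope": 0.0, "intercept": 0.3}
)
```

The correction only changes the outcome for events near the threshold: here the 8.2 prediction drops to about 7.9 and is no longer counted as detected.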
- Parameters:
- gw_param_dict
dict or str GW parameter dictionary or path to JSON file.
- snr_threshold
float SNR threshold for detection classification.
default: 8.0
- error_adjustment
bool If True, apply error correction before thresholding.
default: True
- Returns:
- y_pred
numpy.ndarray of bool Detection status for each sample (True = detected).
Notes
- Calls predict_snr() for SNR prediction
- Returns boolean array: y_pred > snr_threshold