gwsnr.ann

Submodules

Package Contents

Classes

ANNModelGenerator

Generate and train ANN models for gravitational wave SNR prediction.

class gwsnr.ann.ANNModelGenerator(directory='./gwsnr_data', npool=4, gwsnr_verbose=True, snr_th=8.0, snr_method='interpolation_aligned_spins', waveform_approximant='IMRPhenomXPHM', **kwargs)

Generate and train ANN models for gravitational wave SNR prediction.

Provides functionality to train artificial neural network models that predict the optimal SNR of gravitational wave signals from compact binary coalescences. The input features are interpolated partial-scaled SNR values together with a small set of binary parameters (listed under Notes).

Key Features:

  • TensorFlow/Keras-based ANN model training

  • Feature extraction using GWSNR interpolation framework

  • StandardScaler normalization for input features

  • Linear error adjustment for improved prediction accuracy

  • Model evaluation with confusion matrix and accuracy metrics

Parameters:
directory : str

Output directory for saving trained models, scalers, and configurations.

default: './gwsnr_data'

npool : int

Number of processors for parallel GWSNR computation.

default: 4

gwsnr_verbose : bool

If True, print GWSNR initialization progress.

default: True

snr_th : float

SNR threshold for detection classification.

default: 8.0

snr_method : str

SNR calculation method for GWSNR initialization.

default: 'interpolation_aligned_spins'

waveform_approximant : str

Waveform approximant for SNR calculation and ANN training.

default: 'IMRPhenomXPHM'

**kwargs : dict

Additional keyword arguments passed to GWSNR.

See GWSNR documentation for available options.

Notes

  • ANN input features: [partial_scaled_snr, amplitude_factor, eta, chi_eff, theta_jn]

  • Training requires pre-generated GW parameter samples with computed SNR values

  • Error adjustment improves predictions via a linear correction: y_adj = y_pred - (slope*y_pred + intercept)

  • Only single-detector training is supported per instance
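The feature list above can be sketched in numpy. This is a minimal sketch, not gwsnr's implementation: it assumes the standard definitions eta = m1*m2/(m1+m2)^2 and chi_eff = (m1*a_1 + m2*a_2)/(m1+m2) with signed aligned-spin components a_1, a_2, treats partial_scaled_snr and amplitude_factor as precomputed arrays (how gwsnr derives them is not shown), and uses the hypothetical helper names build_features and standardize. The scaling step mimics StandardScaler: training-set mean and population standard deviation.

```python
import numpy as np


def build_features(partial_scaled_snr, amplitude_factor, mass_1, mass_2,
                   a_1, a_2, theta_jn):
    """Stack the five ANN inputs column-wise (hypothetical helper).

    a_1 and a_2 are assumed to be signed aligned-spin components.
    """
    m1 = np.asarray(mass_1, dtype=float)
    m2 = np.asarray(mass_2, dtype=float)
    m_tot = m1 + m2
    eta = m1 * m2 / m_tot**2  # symmetric mass ratio, at most 0.25
    chi_eff = (m1 * np.asarray(a_1, dtype=float)
               + m2 * np.asarray(a_2, dtype=float)) / m_tot  # effective spin
    return np.column_stack([partial_scaled_snr, amplitude_factor,
                            eta, chi_eff, theta_jn])


def standardize(X_train, X_other):
    """Z-score both arrays using statistics from X_train only."""
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)  # population std (ddof=0), as StandardScaler uses
    std = np.where(std == 0.0, 1.0, std)  # constant columns: divide by 1, not 0
    return (X_train - mean) / std, (X_other - mean) / std


X = build_features([10.0, 20.0], [1.0, 0.5], [30.0, 10.0], [30.0, 10.0],
                   [0.0, 0.5], [0.0, -0.5], [0.0, 1.0])
X_scaled, _ = standardize(X, X)
```

Fitting the scaling statistics on the training rows only keeps test-set information out of the normalization.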

Examples

Basic ANN model training:

>>> from gwsnr.ann import ANNModelGenerator
>>> amg = ANNModelGenerator()
>>> amg.ann_model_training(gw_param_dict='gw_param_dict.json')

Custom configuration:

>>> amg = ANNModelGenerator(
...     directory='./custom_output',
...     snr_th=10.0,
...     waveform_approximant='IMRPhenomD')
>>> # params: dict (or JSON path) of GW parameters with precomputed SNRs
>>> amg.ann_model_training(
...     gw_param_dict=params,
...     epochs=200,
...     batch_size=64)

Instance Methods

The ANNModelGenerator class has the following methods:

Method                       Description
---------------------------  ----------------------------------------------------------
ann_model_training()         Train ANN model with parameter data
load_model_scaler_error()    Load pre-trained model, scaler, and error adjustment data
predict_snr()                Predict SNR using trained ANN model
predict_pdet()               Predict detection probability using trained ANN model
pdet_error()                 Calculate detection probability error rate
pdet_confusion_matrix()      Generate confusion matrix for Pdet evaluation
snr_error_adjustment()       Update and save error adjustment parameters

Instance Attributes

The ANNModelGenerator class has the following attributes:

Attribute           Type                    Description
------------------  ----------------------  ----------------------------------------------
directory           str                     Output directory for model files
ann_model           function                ANN model constructor function
ann                 Model or None           Trained Keras model instance
scaler              StandardScaler or None  StandardScaler for feature normalization
gwsnr_args          dict                    GWSNR initialization arguments
gwsnr               GWSNR                   GWSNR instance for interpolation
X_test              ndarray or None         Scaled test input features
y_test              ndarray or None         Test output labels (SNR values)
error_adjustment    dict or None            Error correction parameters (slope, intercept)

property directory

Output directory for model files.

Returns:
directory : str

Output directory path for saving trained models, scalers, and configurations.

default: './gwsnr_data'

property ann_model

ANN model constructor function.

Returns:
ann_model : function

Function that creates and compiles a Keras Sequential model.
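The constructed Sequential model is presumably configured by the num_nodes_list and activation_fn_list arguments of ann_model_training(). The numpy sketch below illustrates one plausible reading of those defaults: it assumes each (units, activation) pair defines one dense layer, applied in order to the 5 input features. The helper names are hypothetical, the weights are random and untrained, and this is not the Keras model itself.

```python
import numpy as np

# Element-wise activations matching the names used in activation_fn_list.
ACTIVATIONS = {
    "relu": lambda z: np.maximum(z, 0.0),
    "sigmoid": lambda z: 1.0 / (1.0 + np.exp(-z)),
    "linear": lambda z: z,
}


def init_dense_stack(n_inputs, num_nodes_list, seed=0):
    """Random (W, b) pairs, one per layer; each layer maps fan_in -> units."""
    rng = np.random.default_rng(seed)
    layers, fan_in = [], n_inputs
    for units in num_nodes_list:
        W = rng.normal(0.0, 1.0 / np.sqrt(fan_in), size=(fan_in, units))
        layers.append((W, np.zeros(units)))
        fan_in = units
    return layers


def forward(X, layers, activation_fn_list):
    """Dense forward pass: h <- activation(h @ W + b) for each layer."""
    h = np.asarray(X, dtype=float)
    for (W, b), name in zip(layers, activation_fn_list):
        h = ACTIVATIONS[name](h @ W + b)
    return h


layers = init_dense_stack(5, [5, 32, 32, 1])
y = forward(np.ones((3, 5)), layers, ["relu", "relu", "sigmoid", "linear"])
```

Under this reading, the final 'linear' layer with one node produces an unbounded scalar per sample, which suits SNR regression.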

property ann

Trained Keras model instance.

Returns:
ann : tensorflow.keras.Model or None

Trained ANN model, or None if not yet trained/loaded.

property scaler

StandardScaler for feature normalization.

Returns:
scaler : sklearn.preprocessing.StandardScaler or None

Fitted scaler for input feature normalization, or None if not fitted.

property gwsnr_args

GWSNR initialization arguments.

Returns:
gwsnr_args : dict

Dictionary of GWSNR configuration parameters.

property gwsnr

GWSNR instance for interpolation.

Returns:
gwsnr : GWSNR

GWSNR instance used for partial-scaled SNR interpolation.

property X_test

Scaled test input features.

Returns:
X_test : numpy.ndarray or None

Scaled test input feature array, or None if not set.

property y_test

Test output labels (SNR values).

Returns:
y_test : numpy.ndarray or None

Test output SNR values array, or None if not set.

property error_adjustment

Error correction parameters.

Returns:
error_adjustment : dict or None

Dictionary with 'slope' and 'intercept' keys for linear error correction, or None if not computed.

ann_model_training(gw_param_dict, randomize=True, test_size=0.1, random_state=42, num_nodes_list=[5, 32, 32, 1], activation_fn_list=['relu', 'relu', 'sigmoid', 'linear'], optimizer='adam', loss='mean_squared_error', metrics=['accuracy'], batch_size=32, epochs=100, error_adjustment_snr_range=[4, 10], ann_file_name='ann_model.h5', scaler_file_name='scaler.pkl', error_adjustment_file_name='error_adjustment.json', ann_path_dict_file_name='ann_path_dict.json')

Train ANN model for SNR prediction using GW parameter data.

Complete training pipeline including data preparation, model training, error adjustment calculation, and file saving.

Parameters:
gw_param_dict : dict or str

GW parameter dictionary or path to JSON file containing training data.

randomize : bool

If True, randomly shuffle the training data.

default: True

test_size : float

Fraction of data to hold out for testing.

default: 0.1

random_state : int

Random state for train/test split reproducibility.

default: 42

num_nodes_list : list of int

Number of nodes in each layer.

default: [5, 32, 32, 1]

activation_fn_list : list of str

Activation functions for each layer.

default: ['relu', 'relu', 'sigmoid', 'linear']

optimizer : str

Keras optimizer name.

default: 'adam'

loss : str

Keras loss function name.

default: 'mean_squared_error'

metrics : list of str

Metrics to evaluate during training.

default: ['accuracy']

batch_size : int

Batch size for training.

default: 32

epochs : int

Number of training epochs.

default: 100

error_adjustment_snr_range : list of float

SNR range [min, max] for computing error adjustment parameters.

default: [4, 10]

ann_file_name : str

Output filename for trained model.

default: 'ann_model.h5'

scaler_file_name : str

Output filename for fitted scaler.

default: 'scaler.pkl'

error_adjustment_file_name : str

Output filename for error adjustment parameters.

default: 'error_adjustment.json'

ann_path_dict_file_name : str

Output filename for ANN configuration paths.

default: 'ann_path_dict.json'

Notes

  • Saves model to directory/{ann_file_name}

  • Saves scaler to directory/{scaler_file_name}

  • Computes error adjustment using _helper_error_adjustment()

  • Stores test data in X_test and y_test
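The data-preparation step (shuffle, hold out test_size, fit the scaler) can be sketched with numpy. This is an illustrative sketch under stated assumptions, not gwsnr's code: it uses random_state to seed the shuffle, sizes the test set with ceil(test_size * n) as sklearn.model_selection.train_test_split does, and computes the scaling statistics from the training rows only, which is the usual practice (whether gwsnr does exactly this is not stated here). split_and_scale is a hypothetical name.

```python
import numpy as np


def split_and_scale(X, y, test_size=0.1, random_state=42, randomize=True):
    """Hold out a test split, then z-score both splits with training stats."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    idx = np.arange(len(X))
    if randomize:
        np.random.default_rng(random_state).shuffle(idx)  # reproducible shuffle
    n_test = int(np.ceil(test_size * len(X)))
    test_idx, train_idx = idx[:n_test], idx[n_test:]

    mean = X[train_idx].mean(axis=0)
    std = X[train_idx].std(axis=0)
    std[std == 0.0] = 1.0  # guard against constant feature columns
    X_train = (X[train_idx] - mean) / std
    X_test = (X[test_idx] - mean) / std
    return X_train, X_test, y[train_idx], y[test_idx]


X_train, X_test, y_train, y_test = split_and_scale(
    np.arange(40.0).reshape(20, 2), np.arange(20.0))
```

The returned X_test and y_test correspond to the held-out arrays that the class stores as attributes of the same names.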

Examples

>>> amg = ANNModelGenerator()
>>> amg.ann_model_training(
...     gw_param_dict='training_params.json',
...     epochs=200,
...     batch_size=64)

pdet_error(gw_param_dict=None, randomize=True, error_adjustment=True)

Calculate detection probability error rate.

Evaluates the percentage of samples where predicted and true detection status (SNR > threshold) differ.

Parameters:
gw_param_dict : dict or str, optional

GW parameter dictionary or JSON file path. If None, the stored test data (X_test, y_test) are used.

randomize : bool

If True, randomly shuffle the parameters (used only when gw_param_dict is provided).

default: True

error_adjustment : bool

If True, apply the linear error correction to predictions.

default: True

Returns:
error : float

Percentage of misclassified samples.

y_pred : numpy.ndarray

Predicted SNR values (with or without error adjustment).

y_test : numpy.ndarray

True SNR values.

Notes

  • Uses gwsnr_args['snr_th'] as detection threshold

  • Error adjustment: y_adj = y_pred - (slope*y_pred + intercept)
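The error-rate calculation can be sketched as follows, assuming predicted and true SNRs are already available as arrays. pdet_error_rate is a hypothetical helper; the strict '>' comparison follows the "SNR > threshold" definition above, and the correction uses the same slope/intercept form as the note.

```python
import numpy as np


def pdet_error_rate(y_pred, y_true, snr_th=8.0, error_adjustment=None):
    """Percentage of samples whose detected/not-detected label disagrees."""
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    if error_adjustment is not None:
        slope = error_adjustment["slope"]
        intercept = error_adjustment["intercept"]
        y_pred = y_pred - (slope * y_pred + intercept)  # linear correction
    mismatch = (y_pred > snr_th) != (y_true > snr_th)
    return 100.0 * np.mean(mismatch), y_pred, y_true


error, y_adj, _ = pdet_error_rate([7.0, 9.0, 9.0, 5.0], [7.5, 8.5, 7.0, 9.0],
                                  error_adjustment={"slope": 0.1, "intercept": 0.0})
```

Only threshold crossings count: a prediction of 9.0 for a true SNR of 8.5 is off by 0.5 but still classified correctly.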

pdet_confusion_matrix(gw_param_dict=None, randomize=True, snr_threshold=8.0)

Generate confusion matrix for detection probability classification.

Evaluates ANN predictions as binary classification (detected/not detected) and computes confusion matrix and accuracy metrics.

Parameters:
gw_param_dict : dict or str, optional

GW parameter dictionary or JSON file path. If None, the stored test data (X_test, y_test) are used.

randomize : bool

If True, randomly shuffle the parameters (used only when gw_param_dict is provided).

default: True

snr_threshold : float

SNR threshold for detection classification.

default: 8.0

Returns:
cm : numpy.ndarray

Confusion matrix of shape (2, 2).

accuracy : float

Classification accuracy percentage.

y_pred : numpy.ndarray

Predicted detection status (boolean array).

y_test : numpy.ndarray

True detection status (boolean array).

Notes

  • Uses sklearn.metrics.confusion_matrix and accuracy_score

  • Prints confusion matrix and accuracy to stdout
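For a boolean detected/not-detected split, the matrix and accuracy reduce to simple counts. The numpy sketch below reproduces that bookkeeping without sklearn; confusion_matrix_2x2 is a hypothetical name, and the inputs stand in for thresholded true and predicted SNRs.

```python
import numpy as np


def confusion_matrix_2x2(det_true, det_pred):
    """2x2 counts with rows = true class, columns = predicted class.

    Index 0 is 'not detected' (False) and index 1 is 'detected' (True),
    the same layout sklearn.metrics.confusion_matrix uses for boolean labels.
    """
    det_true = np.asarray(det_true, dtype=bool)
    det_pred = np.asarray(det_pred, dtype=bool)
    cm = np.zeros((2, 2), dtype=int)
    # np.add.at accumulates correctly when the same (row, col) pair repeats
    np.add.at(cm, (det_true.astype(int), det_pred.astype(int)), 1)
    accuracy = 100.0 * np.mean(det_true == det_pred)
    return cm, accuracy


cm, accuracy = confusion_matrix_2x2([True, True, False, False, False],
                                    [True, False, False, True, False])
```

Here cm[1, 0] counts missed detections (false negatives) and cm[0, 1] counts false alarms.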

load_model_scaler_error(ann_file_name='ann_model.h5', scaler_file_name='scaler.pkl', error_adjustment_file_name=False)

Load pre-trained ANN model, scaler, and optionally error adjustment.

Restores saved model components for prediction use.

Parameters:
ann_file_name : str

Filename of the trained model.

default: 'ann_model.h5'

scaler_file_name : str

Filename of the fitted scaler.

default: 'scaler.pkl'

error_adjustment_file_name : str or bool

Filename of the error adjustment parameters. If False, they are not loaded.

default: False

Returns:
ann : tensorflow.keras.Model

Loaded Keras model.

scaler : sklearn.preprocessing.StandardScaler

Loaded scaler.

error_adjustment : dict, optional

Error adjustment parameters, returned only if error_adjustment_file_name is provided.
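Whether error-adjustment parameters are returned depends on whether a file name is given. A minimal stdlib sketch of that behaviour follows; it assumes error_adjustment.json holds a flat JSON object with 'slope' and 'intercept' keys (the file layout is not specified here), and both helper names are hypothetical. Loading the Keras model and the pickled scaler is omitted.

```python
import json
import os
import tempfile


def save_error_adjustment(directory, file_name, slope, intercept):
    """Write {'slope': ..., 'intercept': ...} to directory/file_name."""
    path = os.path.join(directory, file_name)
    with open(path, "w") as f:
        json.dump({"slope": float(slope), "intercept": float(intercept)}, f)
    return path


def load_error_adjustment(directory, error_adjustment_file_name=False):
    """Return the stored parameters, or None when the file name is False."""
    if not error_adjustment_file_name:
        return None
    with open(os.path.join(directory, error_adjustment_file_name)) as f:
        params = json.load(f)
    return {"slope": float(params["slope"]),
            "intercept": float(params["intercept"])}


if __name__ == "__main__":
    # Round trip in a throwaway directory.
    with tempfile.TemporaryDirectory() as tmp:
        save_error_adjustment(tmp, "error_adjustment.json", 0.1, -0.2)
        print(load_error_adjustment(tmp, "error_adjustment.json"))
```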

snr_error_adjustment(gw_param_dict=None, randomize=True, snr_range=[4, 10], error_adjustment_file_name='error_adjustment.json')

Recalculate and save error adjustment parameters.

Computes new error adjustment based on current predictions and updates the stored parameters.

Parameters:
gw_param_dict : dict or str, optional

GW parameter dictionary or JSON file path used for evaluation.

randomize : bool

If True, randomly shuffle parameters.

default: True

snr_range : list of float

SNR range [min, max] for error adjustment fitting.

default: [4, 10]

error_adjustment_file_name : str

Output filename for updated error adjustment.

default: 'error_adjustment.json'

Returns:
error_adjustment : dict

Updated error adjustment parameters with 'slope' and 'intercept' keys.

Notes

  • Calls pdet_error() with error_adjustment=True for predictions

  • Saves updated parameters to directory/{error_adjustment_file_name}
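How the slope and intercept are fitted is not spelled out here. One reading consistent with y_adj = y_pred - (slope*y_pred + intercept) is a least-squares line through the residual y_pred - y_true as a function of y_pred, restricted to snr_range; subtracting that fitted residual then removes any linear bias in that range. The sketch below uses numpy.polyfit under that assumption, and additionally assumes the range is applied to the true SNR. fit_error_adjustment is a hypothetical name.

```python
import numpy as np


def fit_error_adjustment(y_pred, y_true, snr_range=(4.0, 10.0)):
    """Fit residual = slope * y_pred + intercept inside snr_range (assumed form)."""
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    lo, hi = snr_range
    mask = (y_true >= lo) & (y_true <= hi)  # assumption: range applied to true SNR
    residual = y_pred[mask] - y_true[mask]  # positive = ANN overestimates
    slope, intercept = np.polyfit(y_pred[mask], residual, deg=1)
    return {"slope": float(slope), "intercept": float(intercept)}


y_true = np.array([2.0, 5.0, 6.0, 7.0, 8.0, 9.0, 20.0])
y_pred = 1.1 * y_true + 0.5  # biased predictions inside the range
y_pred[0], y_pred[-1] = 50.0, 0.0  # outliers outside the range are ignored
ea = fit_error_adjustment(y_pred, y_true, (4, 10))
y_adj = y_pred - (ea["slope"] * y_pred + ea["intercept"])
```

With an exactly linear bias, the corrected in-range values recover the true SNRs.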

predict_snr(gw_param_dict, error_adjustment=True)

Predict SNR values using trained ANN model.

Applies the trained model to new GW parameters and optionally applies error correction.

Parameters:
gw_param_dict : dict or str

GW parameter dictionary or path to JSON file.

error_adjustment : bool

If True, apply linear error correction to predictions.

default: True

Returns:
y_pred : numpy.ndarray

Predicted SNR values (corrected if error_adjustment=True).

Notes

  • Requires ann and scaler to be loaded/trained

  • Error adjustment: y_adj = y_pred - (slope*y_pred + intercept)
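The prediction path (scale the inputs, run the network, flatten, optionally correct) can be sketched as follows. The mean/scale arrays stand in for a fitted StandardScaler and the model callable stands in for the Keras network's prediction; predict_snr_sketch and toy_model are hypothetical helpers, not the gwsnr method.

```python
import numpy as np


def predict_snr_sketch(X, mean, scale, model, error_adjustment=None):
    """Scale features, apply a model callable, and optionally correct SNRs."""
    X_scaled = (np.asarray(X, dtype=float) - mean) / scale  # StandardScaler-style
    y_pred = np.asarray(model(X_scaled), dtype=float).ravel()  # (n, 1) -> (n,)
    if error_adjustment is not None:
        y_pred = y_pred - (error_adjustment["slope"] * y_pred
                           + error_adjustment["intercept"])
    return y_pred


def toy_model(Z):
    # Stand-in for a trained network: sums the scaled features per row.
    return Z.sum(axis=1, keepdims=True)


y = predict_snr_sketch([[1.0, 2.0], [3.0, 4.0]], mean=np.zeros(2),
                       scale=np.ones(2), model=toy_model,
                       error_adjustment={"slope": 0.0, "intercept": 1.0})
```

Passing error_adjustment=None corresponds to calling predict_snr() with error_adjustment=False.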

predict_pdet(gw_param_dict, snr_threshold=8.0, error_adjustment=True)

Predict detection probability using trained ANN model.

Classifies events as detected (SNR > threshold) or not detected based on ANN predictions.

Parameters:
gw_param_dict : dict or str

GW parameter dictionary or path to JSON file.

snr_threshold : float

SNR threshold for detection classification.

default: 8.0

error_adjustment : bool

If True, apply error correction before thresholding.

default: True

Returns:
y_pred : numpy.ndarray of bool

Detection status for each sample (True = detected).

Notes

  • Calls predict_snr() for SNR prediction

  • Returns boolean array: y_pred > snr_threshold
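Because the comparison is strict, an event whose predicted SNR equals snr_threshold counts as not detected. The short sketch below shows the thresholding and, as an illustrative extra, the fraction of events flagged as detected; predict_pdet_sketch is a hypothetical name.

```python
import numpy as np


def predict_pdet_sketch(snr_pred, snr_threshold=8.0):
    """Boolean detection status: True where predicted SNR exceeds threshold."""
    return np.asarray(snr_pred, dtype=float) > snr_threshold  # strict '>'


detected = predict_pdet_sketch([7.9, 8.0, 8.1, 12.0])
detected_fraction = detected.mean()  # fraction of events flagged as detected
```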