{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ANN Model creation and testing\n", "\n", "## Contents\n", "\n", "1. Training data generation\n", "2. ANN model training and testing\n", "3. Implementation of the model in GWSNR" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# # If you have not installed the `ler` package, uncomment and run the following line:\n", "# !pip install ler" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Training data generation\n", "\n", "* The training data is generated using the [ler](https://ler.readthedocs.io/en/latest/) package.\n", "* A separate ANN model must be trained for each detector.\n", "* This notebook uses the 'L1' detector with the following parameters:\n", " * sampling frequency: 2048 Hz\n", " * waveform approximant: IMRPhenomXPHM\n", " * minimum frequency: 20 Hz\n", " * PSD (given as an ASD file): aLIGO_O4_high_asd.txt, bundled with the `bilby` package" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from ler.utils import TrainingDataGenerator" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "tdg = TrainingDataGenerator(\n", " npool=8, # number of processes\n", " verbose=False, # set to True when running the code for the first time\n", " # GWSNR parameters\n", " sampling_frequency=2048.,\n", " waveform_approximant='IMRPhenomXPHM', # spin-precessing waveform model\n", " minimum_frequency=20.,\n", " psds={'L1':'aLIGO_O4_high_asd.txt'}, # chosen interferometer is 'L1'. 
If multiple interferometers are chosen, the optimal network SNR is used.\n", " spin_zero=False,\n", " spin_precessing=True,\n", " snr_method='inner_product', # 'interpolation' or 'inner_product'\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* By default, the `ler` package generates astrophysical signals, most of which have low SNR and would not be detected.\n", "\n", "* However, the ANN model needs to be accurate for signals near the detection threshold.\n", "\n", "* Therefore, most of the training data is generated with SNRs near the detection threshold.\n", "\n", "**Note:** Increase the size of the training data set to improve the accuracy of the ANN model." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWRATES class...\n", "\n", "current size of the json file: 1197\n", "\n", "total event to collect: 10000\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 369/369 [00:00<00:00, 378.22it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 1485\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 378/378 [00:00<00:00, 398.39it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 1773\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 314/314 [00:00<00:00, 354.46it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 2043\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 423/423 [00:01<00:00, 399.50it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"Collected number of events: 2412\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 338/338 [00:00<00:00, 362.23it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 2691\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 454/454 [00:01<00:00, 409.53it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 3006\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 449/449 [00:01<00:00, 400.76it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 3348\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 279/279 [00:00<00:00, 348.12it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 3582\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 396/396 [00:01<00:00, 374.89it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 3915\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 387/387 [00:00<00:00, 402.89it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 4257\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 395/395 [00:01<00:00, 383.34it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 4545\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 
377/377 [00:01<00:00, 366.50it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 4833\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 377/377 [00:00<00:00, 389.85it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 5166\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 396/396 [00:01<00:00, 389.59it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 5472\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 405/405 [00:01<00:00, 378.78it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 5769\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 387/387 [00:01<00:00, 382.65it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 6138\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 259/259 [00:00<00:00, 339.38it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 6336\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 315/315 [00:00<00:00, 353.23it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 6588\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 378/378 [00:01<00:00, 363.57it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 6849\n" ] }, { "name": "stderr", 
"output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 297/297 [00:00<00:00, 342.76it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 7092\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 305/305 [00:00<00:00, 361.80it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 7317\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 332/332 [00:00<00:00, 376.09it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 7560\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 377/377 [00:00<00:00, 388.24it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 7875\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 306/306 [00:00<00:00, 348.02it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 8127\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 421/421 [00:01<00:00, 394.64it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 8424\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 377/377 [00:01<00:00, 371.58it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 8775\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 432/432 [00:01<00:00, 393.02it/s]\n" ] }, { "name": "stdout", 
"output_type": "stream", "text": [ "Collected number of events: 9144\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 304/304 [00:00<00:00, 334.92it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 9396\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 466/466 [00:01<00:00, 438.22it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 9765\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 369/369 [00:00<00:00, 390.73it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 10098\n", "final size: 10098\n", "\n", "json file saved at: ./ler_data/IMRPhenomXPHM_O4_high_asd_L1_1.json\n", "\n" ] } ], "source": [ "# re-run this cell if it hangs\n", "ler = tdg.gw_parameters_generator(\n", " size=10000, # number of samples to generate\n", " batch_size=400000, # reduce this number if you run into memory issues\n", " snr_recalculation=True, # pick samples using SNRs from 'interpolation', then recalculate their SNRs with 'inner_product'\n", " trim_to_size=False, verbose=True,\n", " data_distribution_range = [0., 2., 4., 6., 8., 10., 12., 14., 16., 100.], # collect an equal number of samples in each of these SNR bins\n", " replace=False, # set to True to overwrite the existing data\n", " output_jsonfile=\"IMRPhenomXPHM_O4_high_asd_L1_1.json\",\n", ")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWRATES class...\n", "\n", "total event to collect: 5000\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ 
"100%|████████████████████████████████████████████████████████████| 438/438 [00:01<00:00, 426.01it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 376\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 494/494 [00:01<00:00, 444.68it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 816\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 480/480 [00:01<00:00, 417.20it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 1224\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 464/464 [00:01<00:00, 409.93it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 1630\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 482/482 [00:01<00:00, 392.78it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 2060\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 470/470 [00:01<00:00, 387.19it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 2460\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 472/472 [00:01<00:00, 412.21it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 2856\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 470/470 [00:01<00:00, 423.07it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ 
"Collected number of events: 3272\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 530/530 [00:01<00:00, 430.88it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 3712\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 438/438 [00:01<00:00, 405.48it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 4092\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 486/486 [00:01<00:00, 419.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 4516\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 466/466 [00:01<00:00, 379.73it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 4922\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████████| 498/498 [00:01<00:00, 441.36it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 5362\n", "final size: 5362\n", "\n", "json file saved at: ./ler_data/IMRPhenomXPHM_O4_high_asd_L1_2.json\n", "\n" ] } ], "source": [ "# may take ~2-3 min\n", "# reported timing: 10 min 0.7 s for 10000 samples with 8 processes and batch_size=200000\n", "tdg.gw_parameters_generator(\n", " size=5000,\n", " batch_size=200000,\n", " snr_recalculation=True,\n", " trim_to_size=False, verbose=True,\n", " data_distribution_range = [4., 8., 12.], # collect an equal number of samples in each of these SNR bins\n", " replace=False,\n", " output_jsonfile=\"IMRPhenomXPHM_O4_high_asd_L1_2.json\",\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWRATES class...\n", "\n", "total event to collect: 10000\n", "\n", "final size: 10000\n", "\n", "json file saved at: ./ler_data/IMRPhenomXPHM_O4_high_asd_L1_3.json\n", "\n" ] } ], "source": [ "tdg.gw_parameters_generator(\n", " size=10000,\n", " batch_size=10000, \n", " snr_recalculation=True,\n", " trim_to_size=False, \n", " verbose=False, \n", " data_distribution_range = None,\n", " replace=True,\n", " output_jsonfile=\"IMRPhenomXPHM_O4_high_asd_L1_3.json\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Additional random samples" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWSNR class...\n", "\n", "Intel processor has trouble allocating memory when the data is huge. So, by default for IMRPhenomXPHM, duration_max = 64.0. Otherwise, set to some max value like duration_max = 600.0 (10 mins)\n", "\n", "Chosen GWSNR initialization parameters:\n", "\n", "npool: 8\n", "snr type: inner_product\n", "waveform approximant: IMRPhenomXPHM\n", "sampling frequency: 2048.0\n", "minimum frequency (fmin): 20.0\n", "mtot=mass1+mass2\n", "min(mtot): 9.96\n", "max(mtot) (with the given fmin=20.0): 235.0\n", "detectors: ['L1']\n", "psds: [PowerSpectralDensity(psd_file='None', asd_file='/Users/phurailatpamhemantakumar/anaconda3/envs/ler/lib/python3.10/site-packages/bilby/gw/detector/noise_curves/aLIGO_O4_high_asd.txt')]\n", "\n", "\n" ] } ], "source": [ "from gwsnr import GWSNR\n", "import numpy as np\n", "\n", "gwsnr = GWSNR(\n", " npool=8, # number of processes\n", " # GWSNR parameters\n", " sampling_frequency=2048.,\n", " waveform_approximant='IMRPhenomXPHM', # spin-precessing waveform model\n", " minimum_frequency=20.,\n", " psds={'L1':'aLIGO_O4_high_asd.txt'}, # chosen interferometer is 'L1'. 
If multiple interferometers are chosen, the optimal network SNR is used.\n", " snr_method='inner_product', # 'interpolation' or 'inner_product'\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "solving SNR with inner product\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|███████████████████████████████████████████████████████| 50000/50000 [00:40<00:00, 1247.30it/s]\n" ] } ], "source": [ "# general case: random parameters\n", "np.random.seed(64)\n", "nsamples = 50000\n", "mtot = np.random.uniform(2*4.98, 2*112.5, nsamples) # total mass\n", "mass_ratio = np.random.uniform(0.2, 1, size=nsamples) # q = mass_2/mass_1\n", "param_dict = dict(\n", " # convert to component masses\n", " mass_1 = mtot / (1 + mass_ratio),\n", " mass_2 = mtot * mass_ratio / (1 + mass_ratio),\n", " # luminosity distance, uniformly random between 40 and 10000 Mpc\n", " luminosity_distance = np.random.uniform(40, 10000, size=nsamples),\n", " # randomly sample everything else:\n", " theta_jn = np.random.uniform(0,2*np.pi, size=nsamples),\n", " ra = np.random.uniform(0,2*np.pi, size=nsamples), \n", " dec = np.random.uniform(-np.pi/2,np.pi/2, size=nsamples), \n", " psi = np.random.uniform(0,2*np.pi, size=nsamples),\n", " phase = np.random.uniform(0,2*np.pi, size=nsamples),\n", " geocent_time = 1246527224.169434*np.ones(nsamples),\n", " # spin parameters (precessing spins)\n", " a_1 = np.random.uniform(0.0,0.8, size=nsamples),\n", " a_2 = np.random.uniform(0.0,0.8, size=nsamples),\n", " tilt_1 = np.random.uniform(0, np.pi, size=nsamples), # tilt angle of the primary black hole in radians\n", " tilt_2 = np.random.uniform(0, np.pi, size=nsamples), # tilt angle of the secondary black hole in radians\n", " phi_12 = np.random.uniform(0, 2*np.pi, size=nsamples), # azimuthal angle between the projections of the two spins onto the orbital plane, in radians\n", " phi_jl = np.random.uniform(0, 2*np.pi, size=nsamples), # azimuthal angle of the orbital angular momentum about the total angular momentum, in radians\n", ")\n", "\n", "snrs_ = 
gwsnr.optimal_snr(gw_param_dict=param_dict)\n", "# time: ~40 s for 50000 samples with 8 processes (inner_product; see the progress bar in this cell's output)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "param_dict.update(snrs_)\n", "from gwsnr.utils import append_json\n", "append_json(\n", " file_name=\"ler_data/IMRPhenomXPHM_O4_high_asd_L1_4.json\",\n", " new_dictionary=param_dict,\n", " replace=True, # True: overwrite any existing data in this file\n", ");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Combine all the data files into one" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### L1 detector" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "json file saved at: ./ler_data/IMRPhenomXPHM_O4_high_asd_L1.json\n", "\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from ler.utils import TrainingDataGenerator\n", "\n", "tdg = TrainingDataGenerator()\n", "tdg.combine_dicts(\n", " file_name_list=[\"IMRPhenomXPHM_O4_high_asd_L1_1.json\", \"IMRPhenomXPHM_O4_high_asd_L1_2.json\", \"IMRPhenomXPHM_O4_high_asd_L1_3.json\", \"IMRPhenomXPHM_O4_high_asd_L1_4.json\"],\n", " detector='L1',\n", " output_jsonfile=\"IMRPhenomXPHM_O4_high_asd_L1.json\",\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# from gwsnr.utils import get_param_from_json\n", "# test1 = get_param_from_json(\"./ler_data/IMRPhenomXPHM_O4_high_asd_L1.json\")\n", "\n", "# snr = np.array(test1['L1'])\n", "# print(f\"Number of samples: {len(snr)}\")\n", "\n", "# plt.figure(figsize=[4,4])\n", "# plt.hist(snr, bins=100, density=True, alpha=0.5, color='b', histtype='step', label='L1')\n", "# plt.xlim([0, 40])\n", "# plt.xlabel('Optimal SNR')\n", "# plt.ylabel('Density')\n", "# plt.legend()\n", "# plt.show()" ] }, { 
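"cell_type": "markdown", "metadata": {}, "source": [ "The `data_distribution_range` argument used above asks `ler` to collect an equal number of samples in each SNR bin, so that the training set is not dominated by faint, low-SNR events. The next cell is a minimal sketch of that idea in plain NumPy, assuming synthetic SNRs drawn from an exponential distribution (most events faint). It only illustrates the general technique and is not `ler`'s implementation; the helper `balanced_snr_indices` is hypothetical.\n", "\n", "Sampling without replacement inside each bin keeps the selected indices unique, and a bin with fewer than `n_per_bin` events simply contributes all of its events." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Illustrative sketch (assumption): equal-count selection per SNR bin, not ler's implementation\n",
"import numpy as np\n",
"\n",
"def balanced_snr_indices(snr, bin_edges, n_per_bin, rng):\n",
"    # hypothetical helper: pick up to n_per_bin indices from each bin [lo, hi)\n",
"    snr = np.asarray(snr)\n",
"    selected = []\n",
"    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):\n",
"        in_bin = np.flatnonzero((snr >= lo) & (snr < hi))  # indices of samples in this bin\n",
"        take = min(n_per_bin, in_bin.size)  # sparse bins contribute what they have\n",
"        selected.append(rng.choice(in_bin, size=take, replace=False))\n",
"    return np.concatenate(selected)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"# synthetic SNRs (assumption): most events are faint, as in an astrophysical population\n",
"snr_synthetic = rng.exponential(scale=3.0, size=200_000)\n",
"bin_edges = [0., 2., 4., 6., 8., 10., 12., 14., 16., 100.]  # same edges as the first ler run above\n",
"idx = balanced_snr_indices(snr_synthetic, bin_edges, n_per_bin=500, rng=rng)\n",
"counts, _ = np.histogram(snr_synthetic[idx], bins=bin_edges)\n",
"print(counts)  # number of selected samples in each SNR bin"
] }, { 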
"cell_type": "markdown", "metadata": {}, "source": [ "## 2. ANN model training and testing" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from gwsnr.ann import ANNModelGenerator" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWSNR class...\n", "\n", "Intel processor has trouble allocating memory when the data is huge. So, by default for IMRPhenomXPHM, duration_max = 64.0. Otherwise, set to some max value like duration_max = 600.0 (10 mins)\n", "Interpolator will be loaded for L1 detector from ./interpolator_pickle/L1/partialSNR_dict_1.pickle\n", "\n", "\n" ] } ], "source": [ "amg = ANNModelGenerator(\n", " directory='./ann_data',\n", " npool=8,\n", " gwsnr_verbose=False,\n", " snr_th=8.0,\n", " waveform_approximant=\"IMRPhenomXPHM\",\n", " psds={'L1': 'aLIGO_O4_high_asd.txt'}, \n", ")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 309us/step - accuracy: 7.4186e-04 - loss: 1195.4056\n", "Epoch 2/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 303us/step - accuracy: 3.7843e-04 - loss: 786.9496\n", "Epoch 3/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 338us/step - accuracy: 5.5559e-04 - loss: 773.3979\n", "Epoch 4/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 312us/step - accuracy: 4.7510e-04 - loss: 870.4972\n", "Epoch 5/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 
302us/step - accuracy: 5.0608e-04 - loss: 669.6340\n", "Epoch 6/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 340us/step - accuracy: 2.4172e-04 - loss: 599.0781\n", "Epoch 7/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 296us/step - accuracy: 1.1708e-04 - loss: 615.2960\n", "Epoch 8/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 296us/step - accuracy: 4.0820e-04 - loss: 504.8292\n", "Epoch 9/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 325us/step - accuracy: 3.1350e-04 - loss: 473.1070\n", "Epoch 10/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 318us/step - accuracy: 4.6338e-04 - loss: 455.0616\n", "Epoch 11/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 313us/step - accuracy: 3.6646e-04 - loss: 480.0711\n", "Epoch 12/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 331us/step - accuracy: 6.4646e-04 - loss: 430.3829\n", "Epoch 13/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 331us/step - accuracy: 4.6718e-04 - loss: 326.0745\n", "Epoch 14/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 321us/step - accuracy: 4.8356e-04 - loss: 344.0058\n", "Epoch 15/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 333us/step - accuracy: 6.9886e-04 - loss: 318.9627\n", "Epoch 16/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 
360us/step - accuracy: 6.3820e-04 - loss: 316.1301\n", "Epoch 17/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 327us/step - accuracy: 6.0792e-04 - loss: 301.1277\n", "Epoch 18/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 315us/step - accuracy: 6.6761e-04 - loss: 235.1516\n", "Epoch 19/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 315us/step - accuracy: 6.2474e-04 - loss: 272.7290\n", "Epoch 20/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 314us/step - accuracy: 7.3992e-04 - loss: 290.7437\n", "Epoch 21/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 312us/step - accuracy: 4.9990e-04 - loss: 207.5985\n", "Epoch 22/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 319us/step - accuracy: 8.0967e-04 - loss: 263.7404\n", "Epoch 23/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 312us/step - accuracy: 7.1809e-04 - loss: 235.5507\n", "Epoch 24/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 309us/step - accuracy: 6.9704e-04 - loss: 171.0875\n", "Epoch 25/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 314us/step - accuracy: 9.5122e-04 - loss: 188.7950\n", "Epoch 26/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 318us/step - accuracy: 9.8965e-04 - loss: 200.5005\n", "Epoch 27/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m 
\u001b[1m1s\u001b[0m 312us/step - accuracy: 7.1190e-04 - loss: 219.7502\n", "Epoch 28/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 311us/step - accuracy: 8.6508e-04 - loss: 195.2021\n", "Epoch 29/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 313us/step - accuracy: 0.0010 - loss: 162.9744\n", "Epoch 30/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 316us/step - accuracy: 8.3644e-04 - loss: 163.9187\n", "Epoch 31/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 322us/step - accuracy: 9.6832e-04 - loss: 186.6651\n", "Epoch 32/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 350us/step - accuracy: 0.0010 - loss: 160.6446\n", "Epoch 33/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 318us/step - accuracy: 0.0011 - loss: 110.9650\n", "Epoch 34/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 316us/step - accuracy: 0.0011 - loss: 148.0654\n", "Epoch 35/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 320us/step - accuracy: 0.0013 - loss: 118.0238\n", "Epoch 36/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 330us/step - accuracy: 0.0014 - loss: 119.6208\n", "Epoch 37/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 315us/step - accuracy: 0.0013 - loss: 123.7099\n", "Epoch 38/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 
311us/step - accuracy: 0.0014 - loss: 150.7826\n", "Epoch 39/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 311us/step - accuracy: 0.0014 - loss: 132.0962\n", "Epoch 40/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 311us/step - accuracy: 0.0015 - loss: 112.5593\n", "Epoch 41/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 326us/step - accuracy: 0.0012 - loss: 108.4442\n", "Epoch 42/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 318us/step - accuracy: 0.0015 - loss: 131.0688\n", "Epoch 43/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 311us/step - accuracy: 0.0015 - loss: 130.2682\n", "Epoch 44/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 310us/step - accuracy: 0.0018 - loss: 92.5214\n", "Epoch 45/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 312us/step - accuracy: 0.0015 - loss: 118.0394\n", "Epoch 46/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 321us/step - accuracy: 0.0017 - loss: 87.9507\n", "Epoch 47/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 313us/step - accuracy: 0.0016 - loss: 83.3415\n", "Epoch 48/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 346us/step - accuracy: 0.0014 - loss: 114.2209\n", "Epoch 49/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 313us/step - accuracy: 0.0016 - loss: 
56.4827\n", "Epoch 50/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 313us/step - accuracy: 0.0015 - loss: 73.1249\n", "Epoch 51/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 320us/step - accuracy: 0.0015 - loss: 64.7108\n", "Epoch 52/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 311us/step - accuracy: 0.0016 - loss: 80.9203\n", "Epoch 53/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 312us/step - accuracy: 0.0018 - loss: 73.7451\n", "Epoch 54/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 310us/step - accuracy: 0.0012 - loss: 81.1201\n", "Epoch 55/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 322us/step - accuracy: 0.0017 - loss: 79.5091\n", "Epoch 56/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 313us/step - accuracy: 0.0015 - loss: 49.4225\n", "Epoch 57/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 315us/step - accuracy: 0.0018 - loss: 74.7164\n", "Epoch 58/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 312us/step - accuracy: 0.0019 - loss: 70.6322\n", "Epoch 59/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 312us/step - accuracy: 0.0017 - loss: 56.8413\n", "Epoch 60/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 325us/step - accuracy: 0.0018 - loss: 56.4192\n", "Epoch 61/100\n", 
"\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 319us/step - accuracy: 0.0015 - loss: 45.2987\n", "Epoch 62/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 342us/step - accuracy: 0.0017 - loss: 69.9411\n", "Epoch 63/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 316us/step - accuracy: 0.0017 - loss: 79.2437\n", "Epoch 64/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 324us/step - accuracy: 0.0015 - loss: 65.2600\n", "Epoch 65/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 313us/step - accuracy: 0.0018 - loss: 82.6208\n", "Epoch 66/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 311us/step - accuracy: 0.0019 - loss: 33.4129\n", "Epoch 67/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 309us/step - accuracy: 0.0014 - loss: 51.1544\n", "Epoch 68/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 311us/step - accuracy: 0.0015 - loss: 41.7268\n", "Epoch 69/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 319us/step - accuracy: 0.0015 - loss: 47.2896\n", "Epoch 70/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 351us/step - accuracy: 0.0015 - loss: 53.9465\n", "Epoch 71/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 314us/step - accuracy: 0.0020 - loss: 55.2872\n", "Epoch 72/100\n", "\u001b[1m2123/2123\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 310us/step - accuracy: 0.0014 - loss: 37.5868\n", "Epoch 73/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 314us/step - accuracy: 0.0016 - loss: 48.0865\n", "Epoch 74/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 321us/step - accuracy: 0.0018 - loss: 59.0398\n", "Epoch 75/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 313us/step - accuracy: 0.0014 - loss: 80.2237\n", "Epoch 76/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 312us/step - accuracy: 0.0016 - loss: 43.6412\n", "Epoch 77/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 312us/step - accuracy: 0.0015 - loss: 40.7872\n", "Epoch 78/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 310us/step - accuracy: 0.0014 - loss: 40.5094\n", "Epoch 79/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 318us/step - accuracy: 0.0016 - loss: 53.2459\n", "Epoch 80/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 315us/step - accuracy: 0.0019 - loss: 46.4946\n", "Epoch 81/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 313us/step - accuracy: 0.0018 - loss: 39.9218\n", "Epoch 82/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 318us/step - accuracy: 0.0017 - loss: 35.7595\n", "Epoch 83/100\n", "\u001b[1m2123/2123\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 326us/step - accuracy: 0.0017 - loss: 36.7095\n", "Epoch 84/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 371us/step - accuracy: 0.0020 - loss: 30.4304\n", "Epoch 85/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 320us/step - accuracy: 0.0021 - loss: 34.6025\n", "Epoch 86/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 332us/step - accuracy: 0.0020 - loss: 33.1985\n", "Epoch 87/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 333us/step - accuracy: 0.0020 - loss: 25.5399\n", "Epoch 88/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 352us/step - accuracy: 0.0016 - loss: 40.3483\n", "Epoch 89/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 322us/step - accuracy: 0.0016 - loss: 38.8713\n", "Epoch 90/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 312us/step - accuracy: 0.0016 - loss: 36.6795\n", "Epoch 91/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 314us/step - accuracy: 0.0015 - loss: 35.2267\n", "Epoch 92/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 316us/step - accuracy: 0.0018 - loss: 26.4520\n", "Epoch 93/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 313us/step - accuracy: 0.0016 - loss: 27.2608\n", "Epoch 94/100\n", "\u001b[1m2123/2123\u001b[0m 
\u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 329us/step - accuracy: 0.0017 - loss: 33.0001\n", "Epoch 95/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 313us/step - accuracy: 0.0019 - loss: 33.2583\n", "Epoch 96/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 316us/step - accuracy: 0.0017 - loss: 28.9015\n", "Epoch 97/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 328us/step - accuracy: 0.0019 - loss: 24.7544\n", "Epoch 98/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 331us/step - accuracy: 0.0018 - loss: 27.0456\n", "Epoch 99/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 316us/step - accuracy: 0.0018 - loss: 27.1680\n", "Epoch 100/100\n", "\u001b[1m2123/2123\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 351us/step - accuracy: 0.0014 - loss: 33.9702\n", "\u001b[1m236/236\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 297us/step\n", "scaler saved at: ./ann_data/scaler_L1.pkl\n", "model saved at: ./ann_data/ann_model_L1.h5\n", "error adjustment saved at: ./ann_data/error_adjustment_L1.json\n", "ann path dict saved at: ./ann_data/ann_path_dict.json\n" ] } ], "source": [ "amg.ann_model_training(\n", " gw_param_dict='ler_data/IMRPhenomXPHM_O4_high_asd_L1.json', # path to the json file; alternatively, load it into a dict first and pass the dict\n", " randomize=True,\n", " test_size=0.1,\n", " random_state=42,\n", " num_nodes_list=[5, 32, 32, 1],\n", " activation_fn_list=['relu', 'relu', 'sigmoid', 'linear'],\n", " optimizer='adam',\n", " loss='mean_squared_error',\n", " metrics=['accuracy'], # note: 'accuracy' is a classification metric and is not meaningful for SNR regression (hence the ~0.002 values in the training log); 'mae' would be more informative\n", " batch_size=32,\n", " epochs=100,\n", " 
error_adjustment_snr_range=[2,14],\n", " ann_file_name = 'ann_model_L1.h5',\n", " scaler_file_name = 'scaler_L1.pkl',\n", " error_adjustment_file_name='error_adjustment_L1.json',\n", " ann_path_dict_file_name='ann_path_dict.json',\n", ")\n", "\n", "# # Uncomment the following, if you have already trained the model\n", "# # load the trained model\n", "# amg.load_model_scaler_error(\n", "# ann_file_name='ann_model_L1.h5', \n", "# scaler_file_name='scaler_L1.pkl',\n", "# error_adjustment_file_name='error_adjustment_L1.json',\n", "# )" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m236/236\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 242us/step\n", "Error: 3.45%\n" ] }, { "data": { "text/plain": [ "(3.445534057778956,\n", " array([14.857255 , 0.5539622, 2.7361565, ..., 12.17969 ,\n", " 30.18567 , 30.088642 ], dtype=float32),\n", " array([14.48240752, 0.67263087, 3.5234792 , ..., 12.18434598,\n", " 30.22166581, 30.23929325]))" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "amg.pdet_error()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m236/236\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 243us/step\n", "[[5338 85]\n", " [ 195 1928]]\n", "Accuracy: 96.289%\n" ] }, { "data": { "text/plain": [ "(array([[5338, 85],\n", " [ 195, 1928]]),\n", " 96.28942486085343,\n", " array([ True, False, False, ..., True, True, True]),\n", " array([ True, False, False, ..., True, True, True]))" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "amg.pdet_confusion_matrix()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ 
"\u001b[1m2359/2359\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 213us/step\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# predicted snr\n", "pred_snr = amg.predict_snr(gw_param_dict='ler_data/IMRPhenomXPHM_O4_high_asd_L1.json')\n", "# true snr\n", "true_snr = amg.get_parameters(gw_param_dict='ler_data/IMRPhenomXPHM_O4_high_asd_L1.json')['L1']\n", "# select only snr between 4 and 12\n", "snr_min = 4\n", "snr_max = 12\n", "mask = (true_snr >= snr_min) & (true_snr <= snr_max)\n", "true_snr = true_snr[mask]\n", "pred_snr = pred_snr[mask]\n", "\n", "# plot the predicted snr vs true snr\n", "plt.figure(figsize=[4,4])\n", "plt.scatter(true_snr, pred_snr, s=1)\n", "# limits of the y = x reference line, covering both true and predicted SNR\n", "snr_lim = [min(np.min(true_snr), np.min(pred_snr)), max(np.max(true_snr), np.max(pred_snr))]\n", "plt.plot(snr_lim, snr_lim, 'r--')\n", "plt.xlabel('True SNR')\n", "plt.ylabel('Predicted SNR')\n", "plt.xlim([snr_min, snr_max])\n", "plt.ylim([snr_min, snr_max])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[1m2359/2359\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 222us/step\n", "[[53130 947]\n", " [ 1464 19919]]\n", "0.9680492976411343\n" ] } ], "source": [ "# use the following function to predict the pdet\n", "pred_pdet = amg.predict_pdet(gw_param_dict='ler_data/IMRPhenomXPHM_O4_high_asd_L1.json', snr_threshold=8.0)\n", "\n", "true_snr = amg.get_parameters(gw_param_dict='ler_data/IMRPhenomXPHM_O4_high_asd_L1.json')['L1']\n", "# true pdet\n", "true_pdet = np.array([1 if snr >= 8.0 else 0 for snr in true_snr])\n", "\n", "from sklearn.metrics import confusion_matrix, accuracy_score\n", "cm = confusion_matrix(true_pdet, pred_pdet)\n", "print(cm)\n", "\n", "acc = accuracy_score(true_pdet, pred_pdet)\n", "print(acc)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "markdown" } }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 
Implementation of the ANN model in GWSNR\n", "\n", "Generate new astrophysical data and test the trained model on it using the `GWSNR` class." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWRATES class...\n", "\n", "total event to collect: 20000\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|████████████████████████████████████████████████████████| 19507/19507 [00:26<00:00, 724.89it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collected number of events: 20000\n", "final size: 20000\n", "\n", "json file saved at: ./ler_data/IMRPhenomXPHM_O4_high_asd_L1_5.json\n", "\n" ] } ], "source": [ "from ler.utils import TrainingDataGenerator\n", "\n", "# generate a fresh set of events for testing\n", "tdg = TrainingDataGenerator(\n", " npool=4,\n", " verbose=False,\n", " # GWSNR parameters\n", " sampling_frequency=2048,\n", " waveform_approximant='IMRPhenomXPHM',\n", " psds={'L1': 'aLIGO_O4_high_asd.txt'},\n", " minimum_frequency=20,\n", " spin_zero=False,\n", " spin_precessing=True,\n", " snr_method='inner_product',\n", ")\n", "\n", "tdg.gw_parameters_generator(\n", " size=20000,\n", " batch_size=20000,\n", " snr_recalculation=False,\n", " trim_to_size=False,\n", " verbose=True,\n", " data_distribution_range=None,\n", " replace=False,\n", " output_jsonfile=\"IMRPhenomXPHM_O4_high_asd_L1_5.json\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* Using the `GWSNR` class with the trained ANN model, you can compute the optimal SNR of astrophysical GW signals from their parameters." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Initializing GWSNR class...\n", "\n", "Intel processor has trouble allocating memory when the data is huge. So, by default for IMRPhenomXPHM, duration_max = 64.0. 
Otherwise, set to some max value like duration_max = 600.0 (10 mins)\n", "ANN model and scaler path is given. Using the given path.\n", "ANN model for L1 is loaded from ./ann_data/ann_model_L1.h5.\n", "ANN scaler for L1 is loaded from ./ann_data/scaler_L1.pkl.\n", "ANN error_adjustment for L1 is loaded from ./ann_data/error_adjustment_L1.json.\n", "Interpolator will be loaded for L1 detector from ./interpolator_pickle/L1/partialSNR_dict_1.pickle\n", "\n", "Chosen GWSNR initialization parameters:\n", "\n", "npool: 8\n", "snr type: ann\n", "waveform approximant: IMRPhenomXPHM\n", "sampling frequency: 2048.0\n", "minimum frequency (fmin): 20.0\n", "mtot=mass1+mass2\n", "min(mtot): 9.96\n", "max(mtot) (with the given fmin=20.0): 235.0\n", "detectors: ['L1']\n", "psds: [PowerSpectralDensity(psd_file='None', asd_file='/Users/phurailatpamhemantakumar/anaconda3/envs/ler/lib/python3.10/site-packages/bilby/gw/detector/noise_curves/aLIGO_O4_high_asd.txt')]\n", "\n", "\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from gwsnr import GWSNR\n", "\n", "gwsnr = GWSNR(\n", " snr_method='ann',\n", " npool=8, # number of processes\n", " waveform_approximant=\"IMRPhenomXPHM\",\n", " psds={'L1': 'aLIGO_O4_high_asd.txt'},\n", " ann_path_dict='./ann_data/ann_path_dict.json',\n", ")" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# predicted snr, using ANN model \n", "pred_snr = gwsnr.optimal_snr_with_ann(gw_param_dict='./ler_data/IMRPhenomXPHM_O4_high_asd_L1_5.json')['L1']#['snr_net']" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "from gwsnr.utils import get_param_from_json\n", "true_snr = get_param_from_json('./ler_data/IMRPhenomXPHM_O4_high_asd_L1_5.json')['L1']#['snr_net']" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# optionally, uncomment to select only snr between 4 and 12\n", "# snr_min = 4\n", "# snr_max = 12\n", "# mask = (true_snr >= snr_min) & (true_snr <= snr_max)\n", "# true_snr = true_snr[mask]\n", "# pred_snr = pred_snr[mask]\n", "\n", "# plot the predicted snr vs true snr\n", "plt.figure(figsize=[4,4])\n", "plt.scatter(true_snr, pred_snr, s=1)\n", "# limits of the y = x reference line, covering both true and predicted SNR\n", "snr_lim = [min(np.min(true_snr), np.min(pred_snr)), max(np.max(true_snr), np.max(pred_snr))]\n", "plt.plot(snr_lim, snr_lim, 'r--')\n", "plt.xlabel('True SNR')\n", "plt.ylabel('Predicted SNR')\n", "# plt.xlim([snr_min, snr_max])\n", "# plt.ylim([snr_min, snr_max])\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[19966 0]\n", " [ 7 27]]\n", "0.99965\n" ] } ], "source": [ "# predicted pdet: threshold the ANN-predicted SNR at 8\n", "pred_pdet = np.array([1 if snr >= 8.0 else 0 for snr in pred_snr])\n", "# true pdet\n", "true_pdet = np.array([1 if snr >= 8.0 else 0 for snr in true_snr])\n", "\n", "from sklearn.metrics import confusion_matrix, accuracy_score\n", "cm = confusion_matrix(true_pdet, pred_pdet)\n", "print(cm)\n", "\n", "# most astrophysical events are below the threshold, so check the confusion matrix as well as the accuracy\n", "acc = accuracy_score(true_pdet, pred_pdet)\n", "print(acc)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "ler", "language": "python", "name": "ler" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 2 }