import matplotlib.pyplot as plt
import xarray as xr
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.impute import SimpleImputer
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
from tensorflow.keras.utils import Sequence
In [20]:
In [21]:
# Load your dataset
= xr.open_dataset('S5PL2_5D.nc')
ds ds
<xarray.Dataset> Size: 2GB Dimensions: (time: 366, lat: 291, lon: 512, bnds: 2) Coordinates: * lat (lat) float64 2kB 34.36 34.33 34.3 ... 24.97 24.94 24.9 * lon (lon) float64 4kB 68.15 68.19 68.22 ... 84.75 84.79 84.82 * time (time) datetime64[ns] 3kB 2019-01-03T12:00:00 ... 2024-01... time_bnds (time, bnds) datetime64[ns] 6kB ... Dimensions without coordinates: bnds Data variables: AER_AI_340_380 (time, lat, lon) float32 218MB ... AER_AI_354_388 (time, lat, lon) float32 218MB ... CH4 (time, lat, lon) float32 218MB ... CLOUD_FRACTION (time, lat, lon) float32 218MB ... CO (time, lat, lon) float32 218MB ... HCHO (time, lat, lon) float32 218MB ... NO2 (time, lat, lon) float32 218MB ... O3 (time, lat, lon) float32 218MB ... SO2 (time, lat, lon) float32 218MB ... Attributes: Conventions: CF-1.7 title: S5PL2 Data Cube Subset history: [{'program': 'xcube_sh.chunkstore.SentinelHubC... date_created: 2024-05-02T13:00:01.155492 time_coverage_start: 2019-01-01T00:00:00+00:00 time_coverage_end: 2024-01-05T00:00:00+00:00 time_coverage_duration: P1830DT0H0M0S time_coverage_resolution: P5DT0H0M0S geospatial_lon_min: 68.137207 geospatial_lat_min: 24.886436 geospatial_lon_max: 84.836426 geospatial_lat_max: 34.37759367382812
In [23]:
# Stack the features into a single DataArray
= ['SO2', 'NO2', 'CH4', 'O3', 'CO', 'HCHO']
features = xr.concat([ds[feature] for feature in features], dim='feature')
data = data.transpose('time', 'lat', 'lon', 'feature')
data
# Convert to NumPy arrays
= data.values.astype(np.float32)
X_data
# Normalize the input data
= StandardScaler()
scaler = scaler.fit_transform(X_data.reshape(-1, X_data.shape[-1])).reshape(X_data.shape)
X_data
# Impute nan values with the mean of the respective feature
= X_data.reshape(-1, X_data.shape[-1])
X_data_reshaped = SimpleImputer(strategy='mean')
imputer = imputer.fit_transform(X_data_reshaped)
X_data_imputed = X_data_imputed.reshape(X_data.shape)
X_data_imputed
# Add the time dimension to the input data
= np.expand_dims(X_data_imputed, axis=1)
X_data_imputed
# Load your actual target data
= xr.open_dataset('S5PL2_5D.nc')
ds_target = ds_target['AER_AI_340_380'].values.astype(np.float32)
target_data
# Normalize target data to [0, 1]
= MinMaxScaler()
target_scaler = target_scaler.fit_transform(target_data.reshape(-1, 1)).reshape(target_data.shape)
target_data
# Impute nan values in target data
= target_data.reshape(-1, target_data.shape[-1])
target_data_reshaped = imputer.fit_transform(target_data_reshaped)
target_data_imputed = target_data_imputed.reshape(target_data.shape)
target_data_imputed
# Ensure the target data shape is (num_samples, num_timesteps, num_latitudes, num_longitudes, 1)
= target_data_imputed.reshape((target_data.shape[0], 1, target_data.shape[1], target_data.shape[2], 1))
target_data_imputed
# Remove samples with nan values in target data
= ~np.isnan(target_data_imputed).any(axis=(1, 2, 3, 4))
non_nan_target_indices = X_data_imputed[non_nan_target_indices]
X_data_clean = target_data_imputed[non_nan_target_indices]
y_data_clean
# Ensure target values are within the valid range [0, 1]
print("Target data range: ", y_data_clean.max(), y_data_clean.min())
# Split data into training and validation sets
= 0.8
split_ratio = int(split_ratio * X_data_clean.shape[0])
split_idx
= X_data_clean[:split_idx], X_data_clean[split_idx:]
X_train, X_val = y_data_clean[:split_idx], y_data_clean[split_idx:] y_train, y_val
Target data range: 1.0 0.0
In [19]:
print("Input shape:", X_data_clean.shape)
print("Target shape:", y_data_clean.shape)
print("Max and Min of Target Data:", y_data_clean.max(), y_data_clean.min())
Input shape: (366, 1, 291, 512, 6)
Target shape: (366, 1, 291, 512, 1)
Max and Min of Target Data: 1.0 0.0
In [20]:
print("Training data shape:", X_train.shape)
print("Validation data shape:", X_val.shape)
print("Training target shape:", y_train.shape)
print("Validation target shape:", y_val.shape)
Training data shape: (292, 1, 291, 512, 6)
Validation data shape: (74, 1, 291, 512, 6)
Training target shape: (292, 1, 291, 512, 1)
Validation target shape: (74, 1, 291, 512, 1)
In [21]:
import matplotlib.pyplot as plt
# Visualize a sample from the training data
= 0
sample_index 0, :, :, 2], cmap='viridis')
plt.imshow(X_train[sample_index, 'Sample Input')
plt.title(
plt.show()
0, :, :, 0], cmap='viridis')
plt.imshow(y_train[sample_index, 'Sample Target')
plt.title( plt.show()
In [31]:
"X_val.npy", X_val)
np.save("Y_val.npy", y_val) np.save(
In [23]:
# Define the model with correct input shape
= layers.Input(shape=(1, X_data_clean.shape[2], X_data_clean.shape[3], X_data_clean.shape[4]))
inp
= layers.BatchNormalization()(inp)
x = layers.ConvLSTM2D(
x =16,
filters=(3, 3),
kernel_size="same",
padding=True,
return_sequences="tanh",
activation="sigmoid",
recurrent_activation="glorot_uniform"
kernel_initializer
)(x)= layers.BatchNormalization()(x)
x = layers.ConvLSTM2D(
x =32,
filters=(3, 3),
kernel_size="same",
padding=True,
return_sequences="tanh",
activation="sigmoid",
recurrent_activation="glorot_uniform"
kernel_initializer
)(x)= layers.BatchNormalization()(x)
x = layers.Conv3D(
x =1, kernel_size=(3, 3, 3), activation="sigmoid", padding="same"
filters
)(x)
= keras.models.Model(inp, x, name="smogseer")
model
# Use a reduced learning rate and gradient clipping
= keras.optimizers.Adam(learning_rate=1e-5, clipnorm=1.0)
optimizer compile(
model.=keras.losses.binary_crossentropy,
loss=optimizer,
optimizer=['mean_squared_error']
metrics
)
# Print the model summary
model.summary()
# Data Generator Class
class DataGenerator(Sequence):
def __init__(self, X_data, y_data, batch_size):
self.X_data = X_data
self.y_data = y_data
self.batch_size = batch_size
self.indices = np.arange(X_data.shape[0])
def __len__(self):
return int(np.ceil(len(self.indices) / self.batch_size))
def __getitem__(self, index):
= self.indices[index * self.batch_size:(index + 1) * self.batch_size]
batch_indices = self.X_data[batch_indices]
batch_X = self.y_data[batch_indices]
batch_y return batch_X, batch_y
def on_epoch_end(self):
self.indices)
np.random.shuffle(
= 1
batch_size = DataGenerator(X_train, y_train, batch_size)
train_generator = DataGenerator(X_val, y_val, batch_size)
val_generator
# Define callbacks for monitoring and adjusting learning rate
= [
callbacks
keras.callbacks.ReduceLROnPlateau(='val_loss', factor=0.1, patience=10, verbose=1, min_lr=1e-7
monitor
),
keras.callbacks.EarlyStopping(='val_loss', patience=15, verbose=1, restore_best_weights=True
monitor
),='./logs')
keras.callbacks.TensorBoard(log_dir
]
# Train the model using data generators
= model.fit(train_generator, epochs=50, validation_data=val_generator, callbacks=callbacks)
history # Save the model
'smogseer.keras') model.save(
Model: "smogseer"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer_1 (InputLayer) │ (None, 1, 291, 512, 6) │ 0 │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_3 │ (None, 1, 291, 512, 6) │ 24 │ │ (BatchNormalization) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv_lstm2d_2 (ConvLSTM2D) │ (None, 1, 291, 512, │ 12,736 │ │ │ 16) │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_4 │ (None, 1, 291, 512, │ 64 │ │ (BatchNormalization) │ 16) │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv_lstm2d_3 (ConvLSTM2D) │ (None, 1, 291, 512, │ 55,424 │ │ │ 32) │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ batch_normalization_5 │ (None, 1, 291, 512, │ 128 │ │ (BatchNormalization) │ 32) │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv3d_1 (Conv3D) │ (None, 1, 291, 512, 1) │ 865 │ └─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 69,241 (270.47 KB)
Trainable params: 69,133 (270.05 KB)
Non-trainable params: 108 (432.00 B)
Epoch 1/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 150s 496ms/step - loss: 0.7337 - mean_squared_error: 0.1448 - val_loss: 0.6862 - val_mean_squared_error: 0.1305 - learning_rate: 1.0000e-05
Epoch 2/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 137s 469ms/step - loss: 0.7032 - mean_squared_error: 0.1331 - val_loss: 0.6964 - val_mean_squared_error: 0.1351 - learning_rate: 1.0000e-05
Epoch 3/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 473ms/step - loss: 0.6931 - mean_squared_error: 0.1289 - val_loss: 0.7002 - val_mean_squared_error: 0.1363 - learning_rate: 1.0000e-05
Epoch 4/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 472ms/step - loss: 0.6858 - mean_squared_error: 0.1247 - val_loss: 0.6941 - val_mean_squared_error: 0.1334 - learning_rate: 1.0000e-05
Epoch 5/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 473ms/step - loss: 0.6783 - mean_squared_error: 0.1212 - val_loss: 0.6801 - val_mean_squared_error: 0.1267 - learning_rate: 1.0000e-05
Epoch 6/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 139s 478ms/step - loss: 0.6691 - mean_squared_error: 0.1169 - val_loss: 0.6731 - val_mean_squared_error: 0.1232 - learning_rate: 1.0000e-05
Epoch 7/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 141s 483ms/step - loss: 0.6576 - mean_squared_error: 0.1122 - val_loss: 0.6551 - val_mean_squared_error: 0.1144 - learning_rate: 1.0000e-05
Epoch 8/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 147s 503ms/step - loss: 0.6441 - mean_squared_error: 0.1054 - val_loss: 0.6403 - val_mean_squared_error: 0.1071 - learning_rate: 1.0000e-05
Epoch 9/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 149s 510ms/step - loss: 0.6289 - mean_squared_error: 0.0970 - val_loss: 0.6240 - val_mean_squared_error: 0.0991 - learning_rate: 1.0000e-05
Epoch 10/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 144s 492ms/step - loss: 0.6117 - mean_squared_error: 0.0892 - val_loss: 0.6035 - val_mean_squared_error: 0.0891 - learning_rate: 1.0000e-05
Epoch 11/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 142s 485ms/step - loss: 0.5937 - mean_squared_error: 0.0801 - val_loss: 0.5833 - val_mean_squared_error: 0.0794 - learning_rate: 1.0000e-05
Epoch 12/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 137s 469ms/step - loss: 0.5754 - mean_squared_error: 0.0714 - val_loss: 0.5642 - val_mean_squared_error: 0.0702 - learning_rate: 1.0000e-05
Epoch 13/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 472ms/step - loss: 0.5567 - mean_squared_error: 0.0626 - val_loss: 0.5496 - val_mean_squared_error: 0.0634 - learning_rate: 1.0000e-05
Epoch 14/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 473ms/step - loss: 0.5387 - mean_squared_error: 0.0540 - val_loss: 0.5301 - val_mean_squared_error: 0.0543 - learning_rate: 1.0000e-05
Epoch 15/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 137s 469ms/step - loss: 0.5220 - mean_squared_error: 0.0461 - val_loss: 0.5103 - val_mean_squared_error: 0.0453 - learning_rate: 1.0000e-05
Epoch 16/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 142s 486ms/step - loss: 0.5051 - mean_squared_error: 0.0391 - val_loss: 0.4963 - val_mean_squared_error: 0.0390 - learning_rate: 1.0000e-05
Epoch 17/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 147s 503ms/step - loss: 0.4907 - mean_squared_error: 0.0324 - val_loss: 0.4815 - val_mean_squared_error: 0.0326 - learning_rate: 1.0000e-05
Epoch 18/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 147s 502ms/step - loss: 0.4774 - mean_squared_error: 0.0263 - val_loss: 0.4664 - val_mean_squared_error: 0.0263 - learning_rate: 1.0000e-05
Epoch 19/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 146s 501ms/step - loss: 0.4648 - mean_squared_error: 0.0214 - val_loss: 0.4584 - val_mean_squared_error: 0.0230 - learning_rate: 1.0000e-05
Epoch 20/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 142s 485ms/step - loss: 0.4539 - mean_squared_error: 0.0169 - val_loss: 0.4452 - val_mean_squared_error: 0.0177 - learning_rate: 1.0000e-05
Epoch 21/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 139s 475ms/step - loss: 0.4444 - mean_squared_error: 0.0133 - val_loss: 0.4366 - val_mean_squared_error: 0.0144 - learning_rate: 1.0000e-05
Epoch 22/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 472ms/step - loss: 0.4362 - mean_squared_error: 0.0103 - val_loss: 0.4267 - val_mean_squared_error: 0.0108 - learning_rate: 1.0000e-05
Epoch 23/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 473ms/step - loss: 0.4303 - mean_squared_error: 0.0077 - val_loss: 0.4232 - val_mean_squared_error: 0.0096 - learning_rate: 1.0000e-05
Epoch 24/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 139s 475ms/step - loss: 0.4241 - mean_squared_error: 0.0058 - val_loss: 0.4167 - val_mean_squared_error: 0.0073 - learning_rate: 1.0000e-05
Epoch 25/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 145s 497ms/step - loss: 0.4207 - mean_squared_error: 0.0044 - val_loss: 0.4112 - val_mean_squared_error: 0.0054 - learning_rate: 1.0000e-05
Epoch 26/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 147s 502ms/step - loss: 0.4154 - mean_squared_error: 0.0034 - val_loss: 0.4105 - val_mean_squared_error: 0.0052 - learning_rate: 1.0000e-05
Epoch 27/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 146s 499ms/step - loss: 0.4136 - mean_squared_error: 0.0026 - val_loss: 0.4066 - val_mean_squared_error: 0.0040 - learning_rate: 1.0000e-05
Epoch 28/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 146s 500ms/step - loss: 0.4119 - mean_squared_error: 0.0020 - val_loss: 0.4046 - val_mean_squared_error: 0.0033 - learning_rate: 1.0000e-05
Epoch 29/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 139s 476ms/step - loss: 0.4105 - mean_squared_error: 0.0016 - val_loss: 0.4034 - val_mean_squared_error: 0.0030 - learning_rate: 1.0000e-05
Epoch 30/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 474ms/step - loss: 0.4106 - mean_squared_error: 0.0013 - val_loss: 0.4030 - val_mean_squared_error: 0.0029 - learning_rate: 1.0000e-05
Epoch 31/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 472ms/step - loss: 0.4078 - mean_squared_error: 0.0012 - val_loss: 0.4021 - val_mean_squared_error: 0.0026 - learning_rate: 1.0000e-05
Epoch 32/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 471ms/step - loss: 0.4087 - mean_squared_error: 0.0011 - val_loss: 0.4015 - val_mean_squared_error: 0.0024 - learning_rate: 1.0000e-05
Epoch 33/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 139s 475ms/step - loss: 0.4083 - mean_squared_error: 9.7096e-04 - val_loss: 0.4016 - val_mean_squared_error: 0.0024 - learning_rate: 1.0000e-05
Epoch 34/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 141s 483ms/step - loss: 0.4075 - mean_squared_error: 9.5572e-04 - val_loss: 0.4013 - val_mean_squared_error: 0.0023 - learning_rate: 1.0000e-05
Epoch 35/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 139s 475ms/step - loss: 0.4084 - mean_squared_error: 9.2558e-04 - val_loss: 0.4006 - val_mean_squared_error: 0.0021 - learning_rate: 1.0000e-05
Epoch 36/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 137s 469ms/step - loss: 0.4088 - mean_squared_error: 8.9947e-04 - val_loss: 0.4008 - val_mean_squared_error: 0.0022 - learning_rate: 1.0000e-05
Epoch 37/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 139s 477ms/step - loss: 0.4081 - mean_squared_error: 8.2714e-04 - val_loss: 0.4005 - val_mean_squared_error: 0.0021 - learning_rate: 1.0000e-05
Epoch 38/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 147s 504ms/step - loss: 0.4093 - mean_squared_error: 8.5030e-04 - val_loss: 0.4005 - val_mean_squared_error: 0.0021 - learning_rate: 1.0000e-05
Epoch 39/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 146s 500ms/step - loss: 0.4090 - mean_squared_error: 8.4955e-04 - val_loss: 0.4002 - val_mean_squared_error: 0.0020 - learning_rate: 1.0000e-05
Epoch 40/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 148s 507ms/step - loss: 0.4085 - mean_squared_error: 8.1096e-04 - val_loss: 0.4000 - val_mean_squared_error: 0.0019 - learning_rate: 1.0000e-05
Epoch 41/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 140s 481ms/step - loss: 0.4082 - mean_squared_error: 7.9257e-04 - val_loss: 0.4003 - val_mean_squared_error: 0.0020 - learning_rate: 1.0000e-05
Epoch 42/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 473ms/step - loss: 0.4087 - mean_squared_error: 7.7913e-04 - val_loss: 0.4006 - val_mean_squared_error: 0.0021 - learning_rate: 1.0000e-05
Epoch 43/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 139s 474ms/step - loss: 0.4088 - mean_squared_error: 7.6459e-04 - val_loss: 0.3999 - val_mean_squared_error: 0.0019 - learning_rate: 1.0000e-05
Epoch 44/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 473ms/step - loss: 0.4083 - mean_squared_error: 7.5364e-04 - val_loss: 0.4001 - val_mean_squared_error: 0.0020 - learning_rate: 1.0000e-05
Epoch 45/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 473ms/step - loss: 0.4066 - mean_squared_error: 7.8711e-04 - val_loss: 0.3999 - val_mean_squared_error: 0.0019 - learning_rate: 1.0000e-05
Epoch 46/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 145s 495ms/step - loss: 0.4093 - mean_squared_error: 7.6262e-04 - val_loss: 0.3998 - val_mean_squared_error: 0.0019 - learning_rate: 1.0000e-05
Epoch 47/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 147s 504ms/step - loss: 0.4076 - mean_squared_error: 7.3529e-04 - val_loss: 0.3997 - val_mean_squared_error: 0.0019 - learning_rate: 1.0000e-05
Epoch 48/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 145s 496ms/step - loss: 0.4094 - mean_squared_error: 6.9850e-04 - val_loss: 0.3998 - val_mean_squared_error: 0.0019 - learning_rate: 1.0000e-05
Epoch 49/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 147s 503ms/step - loss: 0.4087 - mean_squared_error: 7.1750e-04 - val_loss: 0.3997 - val_mean_squared_error: 0.0019 - learning_rate: 1.0000e-05
Epoch 50/50
292/292 ━━━━━━━━━━━━━━━━━━━━ 138s 472ms/step - loss: 0.4070 - mean_squared_error: 6.8971e-04 - val_loss: 0.3996 - val_mean_squared_error: 0.0018 - learning_rate: 1.0000e-05
Restoring model weights from the end of the best epoch: 50.
In [6]:
## LOAD CHECKPOINTS IF NEEDED
= np.load('X_val.npy')
X_val = np.load('Y_val.npy')
y_val
# Load the model
= load_model('smogseer50.keras')
model
# Run predictions on validation data
= model.predict(X_val)
predictions
# Evaluate the model on validation data
= model.evaluate(X_val, y_val)
val_loss, val_accuracy print(f"Validation Loss: {val_loss}")
print(f"Validation Accuracy: {val_accuracy}")
3/3 ━━━━━━━━━━━━━━━━━━━━ 7s 747ms/step
3/3 ━━━━━━━━━━━━━━━━━━━━ 3s 460ms/step - loss: 0.3992 - mean_squared_error: 0.0016
Validation Loss: 0.39958614110946655
Validation Accuracy: 0.0018263210076838732
In [28]:
# Plot comparisons and training history
def plot_comparison(y_true, y_pred, index, save_path):
"""
Plots the ground truth and the predicted output for a given index.
Parameters:
- y_true: Ground truth data
- y_pred: Predicted data
- index: Index of the sample to plot
- save_path: Path to save the plot
"""
= plt.subplots(1, 2, figsize=(12, 6))
fig, axes
# Plot ground truth
= axes[0]
ax 0, :, :, 0], cmap='viridis')
ax.imshow(y_true[index, 'Ground Truth')
ax.set_title('off')
ax.axis(
# Plot prediction
= axes[1]
ax 0, :, :, 0], cmap='viridis')
ax.imshow(y_pred[index, 'Prediction')
ax.set_title('off')
ax.axis(
plt.tight_layout()
plt.savefig(save_path)
plt.close()
# Visualize a few samples
= 5
num_samples_to_plot for i in range(num_samples_to_plot):
f'comparison_plot_{i}.png')
plot_comparison(y_val, predictions, i,
# Plot training history
def plot_training_history(history, save_path):
"""
Plots the training and validation loss and accuracy over epochs.
Parameters:
- history: Keras History object
- save_path: Path to save the plot
"""
= plt.subplots(1, 2, figsize=(14, 6))
fig, (ax1, ax2)
# Plot loss
'loss'], label='Training Loss')
ax1.plot(history.history['val_loss'], label='Validation Loss')
ax1.plot(history.history['Loss over epochs')
ax1.set_title('Epoch')
ax1.set_xlabel('Loss')
ax1.set_ylabel(
ax1.legend()
# Plot accuracy
'mean_squared_error'], label='Training MSE')
ax2.plot(history.history['val_mean_squared_error'], label='Validation MSE')
ax2.plot(history.history['MSE over epochs')
ax2.set_title('Epoch')
ax2.set_xlabel('MSE')
ax2.set_ylabel(
ax2.legend()
plt.tight_layout()
plt.savefig(save_path)
plt.close()
# Plot training history
'training_history_epoch50.png') plot_training_history(history,
In [24]:
# Visualize a few samples
= 5
num_timesteps_to_plot for i in range(num_timesteps_to_plot):
print(predictions[i, 0, :, :, 0].shape)
(291, 512)
(291, 512)
(291, 512)
(291, 512)
(291, 512)
In [25]:
print("Target data range: ", predictions.max(), predictions.min())
Target data range: 0.99998665 1.6277676e-06