When experimenting with various parameters in a model, such as the number of layers, the activation function, or the number of neurons, sometimes the model does not perform well during training.
We can save training time and other resources by stopping training once the model's performance on a validation dataset stops improving. This is called early stopping. It is a regularization technique that helps prevent overfitting by deciding when to stop training based on the model's performance on held-out validation data.
Implementing early stopping is quite simple in popular deep-learning frameworks such as TensorFlow, Keras, and PyTorch. Below are examples of how to implement early stopping in each of these frameworks.
TensorFlow / Keras (using the EarlyStopping Callback)
TensorFlow / Keras provides a built-in callback called “EarlyStopping” that makes early stopping easy to implement.
Let’s go through the code.
```python
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping

# Load a sample dataset (the MNIST test split serves as validation data here)
(X_train, y_train), (X_val, y_val) = tf.keras.datasets.mnist.load_data()
X_train, X_val = X_train / 255.0, X_val / 255.0

# Define the model
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

# Compile the model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Set up EarlyStopping callback
early_stopping = EarlyStopping(monitor='val_loss',         # Metric to monitor
                               patience=3,                 # Epochs without improvement before stopping
                               restore_best_weights=True)  # Restore the best model's weights

# Train the model with EarlyStopping callback
history = model.fit(X_train, y_train,
                    epochs=50,
                    validation_data=(X_val, y_val),
                    callbacks=[early_stopping])  # Pass the callback to the fit() function
```
```
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
/usr/local/lib/python3.11/dist-packages/keras/src/layers/reshaping/flatten.py:37: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
  super().__init__(**kwargs)
Epoch 1/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 16s 7ms/step - accuracy: 0.8604 - loss: 0.4732 - val_accuracy: 0.9582 - val_loss: 0.1351
Epoch 2/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 16s 4ms/step - accuracy: 0.9566 - loss: 0.1482 - val_accuracy: 0.9692 - val_loss: 0.1010
Epoch 3/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 9s 4ms/step - accuracy: 0.9684 - loss: 0.1054 - val_accuracy: 0.9742 - val_loss: 0.0841
Epoch 4/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 11s 4ms/step - accuracy: 0.9739 - loss: 0.0871 - val_accuracy: 0.9759 - val_loss: 0.0748
Epoch 5/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.9787 - loss: 0.0703 - val_accuracy: 0.9778 - val_loss: 0.0742
Epoch 6/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.9802 - loss: 0.0636 - val_accuracy: 0.9788 - val_loss: 0.0705
Epoch 7/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - accuracy: 0.9827 - loss: 0.0544 - val_accuracy: 0.9782 - val_loss: 0.0716
Epoch 8/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - accuracy: 0.9840 - loss: 0.0518 - val_accuracy: 0.9787 - val_loss: 0.0695
Epoch 9/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.9845 - loss: 0.0460 - val_accuracy: 0.9790 - val_loss: 0.0694
Epoch 10/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - accuracy: 0.9860 - loss: 0.0417 - val_accuracy: 0.9798 - val_loss: 0.0675
Epoch 11/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - accuracy: 0.9887 - loss: 0.0349 - val_accuracy: 0.9809 - val_loss: 0.0672
Epoch 12/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - accuracy: 0.9882 - loss: 0.0364 - val_accuracy: 0.9794 - val_loss: 0.0707
Epoch 13/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 10s 3ms/step - accuracy: 0.9891 - loss: 0.0318 - val_accuracy: 0.9813 - val_loss: 0.0718
Epoch 14/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - accuracy: 0.9892 - loss: 0.0312 - val_accuracy: 0.9808 - val_loss: 0.0684
```

Training stopped after the 14th epoch. The lowest val_loss, 0.0672, was reached at epoch 11. In epochs 12, 13, and 14 the validation loss (0.0707, 0.0718, 0.0684) never dropped below that value, so after patience=3 consecutive epochs without improvement the callback ended training. Since restore_best_weights=True, the model keeps the weights from epoch 11.
Key Parameters of the EarlyStopping Callback in TensorFlow / Keras
- monitor: The quantity to watch, usually 'val_loss' or 'val_accuracy'.
- patience: The number of consecutive epochs without improvement in the monitored quantity after which training stops.
- restore_best_weights: Whether to restore the model weights from the epoch with the best monitored value when training ends. If False, the model keeps the weights from the last epoch.
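To see how monitor and patience interact, here is a plain-Python sketch that replays a patience rule over the val_loss and val_accuracy values copied from the training log above. It is a simplified model of the callback, assuming the default min_delta of 0 and no baseline; simulate_early_stopping is a hypothetical helper, not part of Keras.

```python
import operator


def simulate_early_stopping(values, patience, mode="min"):
    """Replay a patience rule over per-epoch metric values.

    Hypothetical helper that mirrors the idea behind Keras EarlyStopping with
    min_delta=0 and no baseline; it is not the callback's actual source code.
    Returns (stop_epoch, best_epoch), both 1-indexed; stop_epoch is None if
    the rule never fires.
    """
    # "min": lower is better (losses); "max": higher is better (accuracies)
    is_better = operator.lt if mode == "min" else operator.gt
    best_value = None
    best_epoch = None
    wait = 0  # consecutive epochs without improvement
    for epoch, value in enumerate(values, start=1):
        if best_value is None or is_better(value, best_value):
            best_value, best_epoch, wait = value, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Per-epoch values copied from the Keras training log above (epochs 1-14)
val_loss = [0.1351, 0.1010, 0.0841, 0.0748, 0.0742, 0.0705, 0.0716,
            0.0695, 0.0694, 0.0675, 0.0672, 0.0707, 0.0718, 0.0684]
val_accuracy = [0.9582, 0.9692, 0.9742, 0.9759, 0.9778, 0.9788, 0.9782,
                0.9787, 0.9790, 0.9798, 0.9809, 0.9794, 0.9813, 0.9808]

print(simulate_early_stopping(val_loss, patience=3))                  # (14, 11)
print(simulate_early_stopping(val_accuracy, patience=3, mode="max"))  # (None, 13)
```

Monitoring val_loss reproduces what happened in the run: the best epoch is 11 (the weights that restore_best_weights=True brings back) and the rule fires after epoch 14. Monitoring val_accuracy with the same patience would not have stopped training by epoch 14, because accuracy reached a new best at epoch 13, so the choice of monitored metric changes when training ends.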
Implementing Early Stopping in PyTorch
Using a manual method to implement early stopping in PyTorch
PyTorch does not have a built-in early stopping callback like Keras does. However, early stopping is easy to implement manually: track the validation loss after each epoch, and stop training once it has not improved for a set number of consecutive epochs (the patience).
Let’s start coding.
```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# Sample dataset: random inputs and labels, used only to demonstrate the mechanics
X_train = np.random.rand(1000, 28 * 28).astype(np.float32)
y_train = np.random.randint(0, 10, 1000).astype(np.int64)
X_val = np.random.rand(200, 28 * 28).astype(np.float32)
y_val = np.random.randint(0, 10, 200).astype(np.int64)

train_data = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
val_data = TensorDataset(torch.tensor(X_val), torch.tensor(y_val))
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)

# Define a simple model
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ------------------ early stopping implementation ------------------
# Early stopping parameters
patience = 3                      # epochs to wait for an improvement
best_val_loss = float('inf')
epochs_without_improvement = 0

# Training loop
for epoch in range(50):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    val_loss /= len(val_loader)  # average loss over validation batches
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss}')

    # Early stopping condition
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        # Save the model checkpoint
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping...")
            break

# Restore the best weights saved during training
model.load_state_dict(torch.load('best_model.pth'))
```
Output
```
Epoch 1, Validation Loss: 2.345960259437561
Epoch 2, Validation Loss: 2.3512794971466064
Epoch 3, Validation Loss: 2.3335058093070984
Epoch 4, Validation Loss: 2.3214242458343506
Epoch 5, Validation Loss: 2.3485482335090637
Epoch 6, Validation Loss: 2.3570011854171753
Epoch 7, Validation Loss: 2.3309141993522644
Early stopping...
```
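As we can see, training stopped after epoch 7. The best validation loss (about 2.3214) was reached at epoch 4, and epochs 5, 6, and 7 did not improve on it, so the patience of 3 was used up. Because the inputs and labels are random, there is no real pattern to learn, and the validation loss stays close to ln(10) ≈ 2.30, the loss of predicting all 10 classes with equal probability. Your exact numbers will differ, since no random seed is set.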
Key Points:
- Track val_loss or another metric for validation performance.
- Stop training when the metric stops improving for a certain number of epochs (patience).
- Save the best model weights during training and load them after stopping.
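The key points above can be packaged into a small reusable helper so the training loop stays short. This is a minimal sketch under our own assumptions: the class name EarlyStopper and its step() method are hypothetical, not a PyTorch API, and instead of writing a file it keeps an in-memory deep copy of whatever state you pass in (for example model.state_dict()).

```python
import copy


class EarlyStopper:
    """Hypothetical patience-based early stopping helper (not a PyTorch API).

    Call step(val_loss, state) once per epoch; it returns True when training
    should stop. `state` is any snapshot of the model, e.g. model.state_dict().
    """

    def __init__(self, patience=3):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_state = None
        self.epochs_without_improvement = 0

    def step(self, val_loss, state=None):
        if val_loss < self.best_loss:
            # New best: remember the loss and a deep copy of the state, so later
            # in-place parameter updates cannot change the saved snapshot
            self.best_loss = val_loss
            self.best_state = copy.deepcopy(state)
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Demo with the validation losses from the PyTorch output above (rounded);
# a small dict stands in for model.state_dict()
losses = [2.3460, 2.3513, 2.3335, 2.3214, 2.3485, 2.3570, 2.3309]
stopper = EarlyStopper(patience=3)
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss, {"epoch": epoch}):
        print(f"Early stopping after epoch {epoch}")
        break

print(stopper.best_state)  # {'epoch': 4}
```

In the manual loop above, the equivalent is to call stopper.step(val_loss, model.state_dict()) after the validation phase, break when it returns True, and run model.load_state_dict(stopper.best_state) after the loop. The deep copy is the important design choice: state_dict() returns tensors that share memory with the model's parameters, so storing it without copying would let the "best" snapshot change as training continues. Keeping the snapshot in memory avoids disk writes but costs one extra copy of the weights.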
Using the early-stopping-pytorch library for early stopping in PyTorch
Install the library from your terminal:

```bash
pip install early-stopping-pytorch
```

We will reuse the model, data loaders, loss function, and optimizer from the previous example.
```python
import torch
from early_stopping_pytorch import EarlyStopping

# model, criterion, optimizer, train_loader, and val_loader come from the previous example

# Initialize early stopping
early_stopping = EarlyStopping(patience=5, verbose=True)

# Training loop
for epoch in range(50):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)  # shape (batch, 10); no squeeze needed
        loss = criterion(outputs, labels)  # labels are already int64
        loss.backward()
        optimizer.step()

    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    val_loss /= len(val_loader)
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss}')

    # Apply early stopping (saves a checkpoint whenever val_loss improves)
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping...")
        break

# Load the best checkpoint (the library writes it to 'checkpoint.pt' by default)
model.load_state_dict(torch.load('checkpoint.pt'))
```

Output
```
Epoch 1, Validation Loss: 2.566432535648346
Validation loss decreased (inf --> 2.566433). Saving model ...
Epoch 2, Validation Loss: 2.4972275495529175
Validation loss decreased (2.566433 --> 2.497228). Saving model ...
Epoch 3, Validation Loss: 2.595063328742981
EarlyStopping counter: 1 out of 5
Epoch 4, Validation Loss: 2.6587947010993958
EarlyStopping counter: 2 out of 5
Epoch 5, Validation Loss: 2.6591238975524902
EarlyStopping counter: 3 out of 5
Epoch 6, Validation Loss: 2.6106088757514954
EarlyStopping counter: 4 out of 5
Epoch 7, Validation Loss: 2.727466642856598
EarlyStopping counter: 5 out of 5
Early stopping...
```

Conclusion
In this blog, we learned how to implement early stopping in TensorFlow / Keras and PyTorch, and saw how it helps during model training.
Early stopping can improve a model's generalization by halting training before it starts to overfit, and it saves computational resources by stopping once further improvement is unlikely.