How to Implement Early Stopping in TensorFlow, Keras, and PyTorch

When experimenting with a model's parameters, such as the number of layers, the activation function, or the number of neurons, the model sometimes performs poorly during training.

We can save training time and other resources by stopping training when the model's performance on a validation dataset deteriorates. This process is called early stopping. It is a regularization technique that prevents overfitting by halting training based on validation performance.

Implementing early stopping is quite simple in popular deep-learning frameworks such as TensorFlow, Keras, and PyTorch. Below are examples of how to implement early stopping in each of these frameworks.

TensorFlow / Keras (using the EarlyStopping Callback)

In TensorFlow / Keras, we have a built-in callback called EarlyStopping that makes early stopping easy to implement.

Let’s go through the code.

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping

# Load a sample dataset
(X_train, y_train), (X_val, y_val) = tf.keras.datasets.mnist.load_data()
X_train, X_val = X_train / 255.0, X_val / 255.0

# Define the model
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

# Compile the model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Set up EarlyStopping callback
early_stopping = EarlyStopping(monitor='val_loss',   # Metric to monitor
                               patience=3,           # Number of epochs to wait before stopping
                               restore_best_weights=True)  # Restore the best model's weights

# Train the model with EarlyStopping callback
history = model.fit(X_train, y_train,
                    epochs=50,
                    validation_data=(X_val, y_val),
                    callbacks=[early_stopping])  # Pass the callback to the fit() function
Output
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
/usr/local/lib/python3.11/dist-packages/keras/src/layers/reshaping/flatten.py:37: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
  super().__init__(**kwargs)
Epoch 1/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 16s 7ms/step - accuracy: 0.8604 - loss: 0.4732 - val_accuracy: 0.9582 - val_loss: 0.1351
Epoch 2/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 16s 4ms/step - accuracy: 0.9566 - loss: 0.1482 - val_accuracy: 0.9692 - val_loss: 0.1010
Epoch 3/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 9s 4ms/step - accuracy: 0.9684 - loss: 0.1054 - val_accuracy: 0.9742 - val_loss: 0.0841
Epoch 4/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 11s 4ms/step - accuracy: 0.9739 - loss: 0.0871 - val_accuracy: 0.9759 - val_loss: 0.0748
Epoch 5/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.9787 - loss: 0.0703 - val_accuracy: 0.9778 - val_loss: 0.0742
Epoch 6/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.9802 - loss: 0.0636 - val_accuracy: 0.9788 - val_loss: 0.0705
Epoch 7/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - accuracy: 0.9827 - loss: 0.0544 - val_accuracy: 0.9782 - val_loss: 0.0716
Epoch 8/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - accuracy: 0.9840 - loss: 0.0518 - val_accuracy: 0.9787 - val_loss: 0.0695
Epoch 9/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 10s 4ms/step - accuracy: 0.9845 - loss: 0.0460 - val_accuracy: 0.9790 - val_loss: 0.0694
Epoch 10/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - accuracy: 0.9860 - loss: 0.0417 - val_accuracy: 0.9798 - val_loss: 0.0675
Epoch 11/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - accuracy: 0.9887 - loss: 0.0349 - val_accuracy: 0.9809 - val_loss: 0.0672
Epoch 12/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - accuracy: 0.9882 - loss: 0.0364 - val_accuracy: 0.9794 - val_loss: 0.0707
Epoch 13/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 10s 3ms/step - accuracy: 0.9891 - loss: 0.0318 - val_accuracy: 0.9813 - val_loss: 0.0718
Epoch 14/50
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - accuracy: 0.9892 - loss: 0.0312 - val_accuracy: 0.9808 - val_loss: 0.0684
We can see that training stopped at epoch 14. The best val_loss (0.0672) came at epoch 11; over the next three epochs (12 through 14) it never improved on that value, so with patience=3 the callback ended training and, since restore_best_weights=True, rolled the weights back to the epoch-11 state.
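
We can confirm this programmatically from the History object that fit() returns; a quick sketch:

import numpy as np

# fit() records per-epoch metrics in history.history
val_loss = history.history['val_loss']
print(f'Trained for {len(val_loss)} epochs')
print(f'Best val_loss: {min(val_loss):.4f} at epoch {int(np.argmin(val_loss)) + 1}')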

Explanation of Key Parameters of the EarlyStopping Callback in TensorFlow / Keras

  • monitor: The metric to watch, usually ‘val_loss’ or ‘val_accuracy’.
  • patience: The number of epochs to wait for an improvement before stopping training.
  • restore_best_weights: Whether to restore the model weights from the best epoch once training stops (see the fuller example below).
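
Besides these three, EarlyStopping accepts a few other useful arguments. A minimal sketch (the values below are illustrative, not recommendations):

from tensorflow.keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(
    monitor='val_loss',         # Quantity to watch each epoch
    min_delta=1e-4,             # Smallest change that counts as an improvement
    patience=3,                 # Epochs without improvement before stopping
    mode='min',                 # 'min' for losses, 'max' for accuracy-style metrics
    restore_best_weights=True,  # Roll the weights back to the best epoch
    verbose=1                   # Log a message when training stops early
)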

Implementing Early Stopping in PyTorch

Using a manual method to implement early stopping in PyTorch

PyTorch does not provide a built-in early-stopping callback like Keras does. However, we can easily implement it manually: track the validation loss across epochs, and if it does not improve for a certain number of epochs, stop training.

Let’s start coding.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

# Sample dataset and model
X_train = np.random.rand(1000, 28 * 28).astype(np.float32)
y_train = np.random.randint(0, 10, 1000).astype(np.int64)
X_val = np.random.rand(200, 28 * 28).astype(np.float32)
y_val = np.random.randint(0, 10, 200).astype(np.int64)

train_data = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
val_data = TensorDataset(torch.tensor(X_val), torch.tensor(y_val))

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64)

# Define a simple model
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ------------------ early stopping implementation ------------------

# Early stopping parameters
patience = 3
best_val_loss = float('inf')
epochs_without_improvement = 0

# Training loop
for epoch in range(50):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    
    val_loss /= len(val_loader)
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss}')

    # Early stopping condition
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        # Save the model checkpoint
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping...")
            break

Output

Epoch 1, Validation Loss: 2.345960259437561
Epoch 2, Validation Loss: 2.3512794971466064
Epoch 3, Validation Loss: 2.3335058093070984
Epoch 4, Validation Loss: 2.3214242458343506
Epoch 5, Validation Loss: 2.3485482335090637
Epoch 6, Validation Loss: 2.3570011854171753
Epoch 7, Validation Loss: 2.3309141993522644
Early stopping...

As we can see, training stopped at epoch 7: the best val_loss (about 2.3214) came at epoch 4, and the following three epochs brought no improvement, which exhausted the patience of 3.

Key Points:

  • Track val_loss or another metric for validation performance.
  • Stop training when the metric stops improving for a certain number of epochs (patience).
  • Save the best model weights during training and load them after stopping (a sketch of the reload step follows).
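
The manual loop above saves the best checkpoint but never reloads it. A minimal sketch of the reload step, using the 'best_model.pth' path from the loop:

# Reload the weights from the best epoch saved during training
model.load_state_dict(torch.load('best_model.pth'))
model.eval()  # switch to evaluation mode before validating or inference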

Using the early-stopping-pytorch library for early stopping in PyTorch

We can also use the early-stopping-pytorch library to implement early stopping. First, install it from the command line:

pip install early-stopping-pytorch

We will reuse the model, data loaders, loss function, and optimizer from the manual example above.

from early_stopping_pytorch import EarlyStopping

# Initialize early stopping
early_stopping = EarlyStopping(patience=5, verbose=True) 

# Training loop
for epoch in range(50):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
    
    val_loss /= len(val_loader)
    print(f'Epoch {epoch+1}, Validation Loss: {val_loss}')

    # Apply early stopping
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping...")
        break
Output
Epoch 1, Validation Loss: 2.566432535648346
Validation loss decreased (inf --> 2.566433).  Saving model ...
Epoch 2, Validation Loss: 2.4972275495529175
Validation loss decreased (2.566433 --> 2.497228).  Saving model ...
Epoch 3, Validation Loss: 2.595063328742981
EarlyStopping counter: 1 out of 5
Epoch 4, Validation Loss: 2.6587947010993958
EarlyStopping counter: 2 out of 5
Epoch 5, Validation Loss: 2.6591238975524902
EarlyStopping counter: 3 out of 5
Epoch 6, Validation Loss: 2.6106088757514954
EarlyStopping counter: 4 out of 5
Epoch 7, Validation Loss: 2.727466642856598
EarlyStopping counter: 5 out of 5
Early stopping...
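
While training, the library's EarlyStopping instance checkpoints the model whenever the validation loss improves (the "Saving model ..." lines above). After stopping, we can restore those weights; 'checkpoint.pt' below is assumed to be the library's default checkpoint path, so verify it against your installed version:

# Restore the best weights saved by the EarlyStopping instance
# ('checkpoint.pt' is assumed to be the library's default path)
model.load_state_dict(torch.load('checkpoint.pt'))
model.eval()
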
Conclusion

In this blog, we learned how to implement early stopping in TensorFlow, Keras, and PyTorch, and saw how it helps us while training a model.

By using early stopping, you can significantly improve your model's generalization by avoiding overfitting, and potentially save computational resources by stopping training once further improvement is unlikely.
