DEV Community

Kedar Kodgire

Originally published at educative.io

What is overfitting in machine learning?

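Overfitting is a common challenge when training machine learning models. An overfitted model learns the training data too closely, including its noise, so it performs well on that data but poorly on new, unseen data.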

To understand this concept, let’s consider a plot with an X-axis and a Y-axis. The X-axis represents the complexity of the model, and the Y-axis represents the loss, that is, the error the model makes as measured by a loss function.

[Figure: training loss (blue) and validation loss (orange) plotted against model complexity]

Explanation

In the diagram above, the blue line represents the loss on the training set, and the orange line represents the loss on the validation set.
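The article does not say which loss function is plotted on the Y-axis. As an illustration only, here is a minimal Python sketch of mean squared error (MSE), a common choice for regression, which scores a model's predictions on a set of data points. The helper `mean_squared_error` below is a hypothetical stand-alone function (scikit-learn offers a similar `sklearn.metrics.mean_squared_error`).

```python
def mean_squared_error(y_true, y_pred):
    """Average of the squared differences between targets and predictions."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("need at least one value")
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Predictions that are off by 1 and by 3: (1 + 9) / 2
print(mean_squared_error([2.0, 4.0], [3.0, 1.0]))  # 5.0
```

Computing this on the training points gives a value on the blue curve; computing it on held-out validation points gives a value on the orange curve.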

The training set

Let’s consider the blue line.

As the complexity of the model increases, the training loss decreases; a simpler model has a higher training loss.

The validation set

Now, let’s consider the orange line.

As we move from left to right, the validation loss falls until it reaches a certain point. If we keep moving right beyond that point, the validation loss starts to increase.

The point where the validation loss stops falling and starts rising is the minimum loss for the validation set.
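Both curves can be reproduced with a small experiment. The following Python sketch is a minimal illustration under assumptions of our own: "complexity" is the degree of a polynomial, the loss is mean squared error, and the data is synthetic (y = x² plus Gaussian noise from a seeded generator). The tiny least-squares solver `polyfit` is a hypothetical helper included only to keep the example dependency-free; in practice you would likely use a library routine such as `numpy.polyfit`.

```python
import random


def polyfit(xs, ys, degree):
    """Least-squares polynomial fit; returns [c0, c1, ...] for c0 + c1*x + c2*x**2 + ..."""
    n = degree + 1
    # Normal equations (V^T V) c = V^T y, where V is the Vandermonde matrix of xs.
    a = [[sum(x ** (i + j) for x in xs) for j in range(n)] for i in range(n)]
    b = [sum(y * x ** i for x, y in zip(xs, ys)) for i in range(n)]
    # Gaussian elimination with partial pivoting.
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        a[col], a[pivot] = a[pivot], a[col]
        b[col], b[pivot] = b[pivot], b[col]
        for row in range(col + 1, n):
            factor = a[row][col] / a[col][col]
            for k in range(col, n):
                a[row][k] -= factor * a[col][k]
            b[row] -= factor * b[col]
    # Back substitution.
    coeffs = [0.0] * n
    for i in reversed(range(n)):
        coeffs[i] = (b[i] - sum(a[i][k] * coeffs[k] for k in range(i + 1, n))) / a[i][i]
    return coeffs


def polyval(coeffs, x):
    return sum(c * x ** i for i, c in enumerate(coeffs))


def mse(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical data: y = x**2 plus Gaussian noise, seeded for repeatability.
rng = random.Random(0)
train_x = [-2.0 + 0.5 * i for i in range(9)]  # -2.0, -1.5, ..., 2.0
val_x = [-1.75 + 0.5 * i for i in range(8)]   # -1.75, -1.25, ..., 1.75
train_y = [x ** 2 + rng.gauss(0.0, 0.5) for x in train_x]
val_y = [x ** 2 + rng.gauss(0.0, 0.5) for x in val_x]

# Sweep model complexity (polynomial degree) and record both losses.
results = []
for degree in range(7):
    coeffs = polyfit(train_x, train_y, degree)
    train_loss = mse(train_y, [polyval(coeffs, x) for x in train_x])
    val_loss = mse(val_y, [polyval(coeffs, x) for x in val_x])
    results.append((degree, train_loss, val_loss))
    print(f"degree={degree}  train={train_loss:.3f}  val={val_loss:.3f}")

# The balanced complexity is where the validation loss is smallest.
best_degree = min(results, key=lambda r: r[2])[0]
print("lowest validation loss at degree", best_degree)
```

Because each higher-degree model contains the lower-degree ones, the training loss can only stay the same or go down as the degree grows, matching the blue line. The validation loss has no such guarantee, and its exact minimum depends on the random noise, so treat `best_degree` as the result of this particular toy run.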

Balanced model

A model with less complexity than this point is considered underfitted, and a model with more complexity is considered overfitted. The level of complexity at this point represents the balanced model.
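The rule above can be written as a tiny helper. This is a hypothetical sketch, assuming complexity is measured as a single number (for example, a polynomial degree) and that the balanced complexity has already been found as the validation-loss minimum; the value 2 below is chosen only for illustration.

```python
def fit_label(complexity, balanced_complexity):
    """Classify a model by comparing its complexity with the validation-loss minimum."""
    if complexity < balanced_complexity:
        return "underfit"
    if complexity > balanced_complexity:
        return "overfit"
    return "balanced"


# Hypothetical: suppose the validation loss was lowest at polynomial degree 2.
for degree in range(5):
    print(degree, fit_label(degree, balanced_complexity=2))
# 0 underfit
# 1 underfit
# 2 balanced
# 3 overfit
# 4 overfit
```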

To better understand the balanced model, let’s look at the same data points fitted in each of the three ways: underfit, overfit, and balanced.

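[Figure: Figs. A, B, and C show an underfit, an overfit, and a balanced model fitted to the same training points]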

Explanation

In the diagrams above, the red data points are the ones we use to train the model; they were already available before training.

The green data point is newly introduced, and we test it against our model.

In the balanced model, the new green data point lies close to the plotted curve, which means the model predicts it accurately.

Underfit model

In Fig. A, our ML model has fitted a straight line to the data points. The line passes through only a few data points, and the others lie farther from it.

When a new data point is introduced, it is also far from the line. The model is too simple to capture the pattern in the data, so it is inaccurate on both the training data and new data.
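Here is a minimal Python sketch of an underfit model. It assumes hypothetical training data: five points on the curve y = x² with a small alternating ±0.3 offset standing in for noise, plus a new point at x = 3 on the true curve. The helper `fit_line` is our own; it computes an ordinary least-squares straight line, which cannot follow the curve.

```python
def fit_line(xs, ys):
    """Ordinary least-squares straight line; returns (intercept, slope)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return mean_y - slope * mean_x, slope


# Hypothetical training data: y = x**2 plus an alternating +/-0.3 offset as "noise".
xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
noise = [0.3, -0.3, 0.3, -0.3, 0.3]
ys = [x ** 2 + e for x, e in zip(xs, noise)]

intercept, slope = fit_line(xs, ys)


def predict(x):
    return intercept + slope * x


train_mse = sum((y - predict(x)) ** 2 for x, y in zip(xs, ys)) / len(xs)
new_x, new_y = 3.0, 9.0  # a new point on the underlying curve y = x**2
print(f"train MSE={train_mse:.2f}, prediction at x=3: {predict(new_x):.2f} (true {new_y})")
# train MSE=3.37, prediction at x=3: 2.06 (true 9.0)
```

Because these points are symmetric around x = 0, the best straight line is essentially flat (slope ≈ 0), so it misses both the training points and the new point by a wide margin.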

Overfit model

As seen in Fig. B, the model fits the training data exceptionally well: the plotted line passes through every training data point.

This model is perfectly accurate on the training data. However, the new data point lies far from the plotted line, which shows that the model does not work well on new data: it has learned the noise in the training set rather than the underlying pattern.
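An overfit model can be sketched by going to the other extreme: a polynomial with as many coefficients as there are training points can pass through all of them. This Python sketch uses Lagrange interpolation (our choice for illustration, not something the article specifies) on hypothetical data: five points on y = x² with an alternating ±0.3 offset as noise, and a new point at x = 3 on the true curve.

```python
def lagrange_interpolate(xs, ys, x):
    """Evaluate, at x, the unique degree-(len(xs) - 1) polynomial through every (xs[i], ys[i])."""
    total = 0.0
    for i, (xi, yi) in enumerate(zip(xs, ys)):
        basis = 1.0
        for j, xj in enumerate(xs):
            if j != i:
                basis *= (x - xj) / (xi - xj)
        total += yi * basis
    return total


# Hypothetical training data: y = x**2 plus an alternating +/-0.3 offset as "noise".
xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
noise = [0.3, -0.3, 0.3, -0.3, 0.3]
ys = [x ** 2 + e for x, e in zip(xs, noise)]

# The interpolating polynomial passes through every training point...
train_mse = sum((y - lagrange_interpolate(xs, ys, x)) ** 2 for x, y in zip(xs, ys)) / len(xs)

# ...but badly misses a new point from the underlying curve y = x**2.
new_x, new_y = 3.0, 9.0
prediction = lagrange_interpolate(xs, ys, new_x)
print(f"train MSE={train_mse:.2f}, prediction at x=3: {prediction:.2f} (true {new_y})")
# train MSE=0.00, prediction at x=3: 18.30 (true 9.0)
```

The training error is exactly zero, yet the prediction at x = 3 is roughly twice the true value, because the polynomial bends to follow the ±0.3 offsets instead of the underlying curve.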

Balanced model

In Fig. C, which is our balanced model, the plotted curve passes through some of the data points. If we calculate the loss on new (validation) data, this model has the lowest loss of the three models.

The new data point lies close to the plotted curve. Hence, when the model is used with real-world (production) data, its predictions should be reasonably accurate.
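The three fits can be compared side by side. This Python sketch assumes hypothetical training data (five points on y = x² with an alternating ±0.3 offset as noise) and a new point at x = 3 on the true curve. It fits a straight line (degree 1), a parabola (degree 2), and a degree-4 polynomial that passes through all five points; matching these to the underfit, balanced, and overfit cases is our assumption. The hypothetical `polyfit` helper is a small least-squares solver included only to keep the example self-contained.

```python
def polyfit(xs, ys, degree):
    """Least-squares polynomial fit; returns [c0, c1, ...] for c0 + c1*x + c2*x**2 + ..."""
    n = degree + 1
    # Normal equations (V^T V) c = V^T y, where V is the Vandermonde matrix of xs.
    a = [[sum(x ** (i + j) for x in xs) for j in range(n)] for i in range(n)]
    b = [sum(y * x ** i for x, y in zip(xs, ys)) for i in range(n)]
    # Gaussian elimination with partial pivoting.
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        a[col], a[pivot] = a[pivot], a[col]
        b[col], b[pivot] = b[pivot], b[col]
        for row in range(col + 1, n):
            factor = a[row][col] / a[col][col]
            for k in range(col, n):
                a[row][k] -= factor * a[col][k]
            b[row] -= factor * b[col]
    # Back substitution.
    coeffs = [0.0] * n
    for i in reversed(range(n)):
        coeffs[i] = (b[i] - sum(a[i][k] * coeffs[k] for k in range(i + 1, n))) / a[i][i]
    return coeffs


def polyval(coeffs, x):
    return sum(c * x ** i for i, c in enumerate(coeffs))


# Hypothetical training data: y = x**2 plus an alternating +/-0.3 offset as "noise".
xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
noise = [0.3, -0.3, 0.3, -0.3, 0.3]
ys = [x ** 2 + e for x, e in zip(xs, noise)]
new_x, new_y = 3.0, 9.0  # a new point on the underlying curve

errors = {}
for degree, label in [(1, "underfit"), (2, "balanced"), (4, "overfit")]:
    prediction = polyval(polyfit(xs, ys, degree), new_x)
    errors[label] = abs(new_y - prediction)
    print(f"degree {degree} ({label}): prediction {prediction:.2f}, error {errors[label]:.2f}")
# degree 1 (underfit): prediction 2.06, error 6.94
# degree 2 (balanced): prediction 9.66, error 0.66
# degree 4 (overfit): prediction 18.30, error 9.30
```

In this toy setup the degree-2 model, whose shape matches the underlying curve, has the smallest error on the new point, even though the degree-4 model has the smaller training error.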

That's all for this post, guys!
If you enjoyed it, don't forget to leave a like.
Happy learning!
