DEV Community

VaishakhVipin


Linear Regression in a Nutshell

Every day, something exciting comes up in machine learning.

A new RL technique, a transformer architecture that is 0.001% more effective than GPT-2, synthetic data creation to train neural nets, and whatnot.

But before diving into all of that, let us fondly remember a simpler, time-tested algorithm that is arguably more efficient for less complex problems. Ladies and gentlemen, I'M TALKING ABOUT NONE OTHER THAN... LINEAR REGRESSION!

Yes, you heard me right, I'm going to show you the power of linear regression.

If I had to put it in a single sentence: linear regression is a machine learning model that tries to find the linear equation

$$f(x) = ax + b$$

that best fits our data.

It's not all sunshine and rainbows, though, as we immediately run into our first major issue: how do we define "fitting our data the best"?

What do we mean by "fitting our data the best"?

We have a few ways in which we could approach this problem.
But before that, let us define the residual (error) of each data point: the gap between the observed value yᵢ and the line's prediction f(xᵢ).

$$E_i = y_i - f(x_i)$$
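To make the residual concrete, here is a small NumPy sketch. It reuses the pizza-size data from the code demo at the end of this post, together with a hand-picked line a = 1, b = -3 (an arbitrary guess for illustration, not the fitted line).

```python
import numpy as np

# Pizza data from the code demo: size in inches, happiness out of 10
x = np.array([6, 8, 10, 12, 14], dtype=float)
y = np.array([3, 5, 7, 9, 10], dtype=float)

# A hand-picked candidate line f(x) = a*x + b (a guess, not the fitted line)
a, b = 1.0, -3.0

predictions = a * x + b          # f(x_i) for every data point
residuals = y - predictions      # E_i = y_i - f(x_i)

# This line hits the first four points exactly and overshoots the last by 1
print(residuals)
```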

Great! With that out of the way, how do we combine all these residuals into a single cost that we can minimize to find the best line?

Method 1:

$$\sum_{i=1}^{n} E_i$$

Whoa, don't run away just yet. Let me explain. All this does is add up the residuals of every data point in our dataset (the net deviation of the data from the predicted line) and look for the line that makes this total as close to zero as possible.

If you have a sharp eye, though, you'll notice a problem: large positive residuals and large negative residuals cancel out when added, so a total near zero does not mean the line fits well. In fact, every line that passes through the mean point (x̄, ȳ) gives a total of exactly zero, so this criterion cannot pick out a single line.
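A quick NumPy check of this cancellation problem, using the pizza data from the code demo at the end of this post (the slopes tried are arbitrary picks). Algebraically, Σ(yᵢ - axᵢ - b) = n(ȳ - a x̄ - b), which vanishes whenever b = ȳ - a x̄, whatever the slope a.

```python
import numpy as np

# Pizza data from the code demo: size in inches, happiness out of 10
x = np.array([6, 8, 10, 12, 14], dtype=float)
y = np.array([3, 5, 7, 9, 10], dtype=float)


def sum_of_residuals(a, b):
    """Method 1: plain sum of E_i = y_i - (a * x_i + b)."""
    return float(np.sum(y - (a * x + b)))


x_mean, y_mean = x.mean(), y.mean()

# Arbitrary slopes; each intercept is chosen so the line passes through (x_mean, y_mean)
for a in [0.9, 0.0, -5.0]:
    b = y_mean - a * x_mean
    print(f"slope={a:5.1f}  intercept={b:6.1f}  sum of residuals={sum_of_residuals(a, b):.1e}")
```

All three sums come out as zero, or as tiny values around 1e-15 that are just floating-point rounding, even though the lines point in completely different directions.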

Well, what else can we do?

Method 2:

$$\sum_{i=1}^{n} \lvert E_i \rvert$$

You may have already caught this idea: if negative and positive values cancel each other out, just take the absolute value. Look closely, though, and you'll see that this too can produce more than one best line for some datasets, with several (or even infinitely many) lines sharing the same minimum. If you don't want to take my word for it, try it out with a custom dataset in graphing software like Desmos.
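Here is one such dataset, sketched in NumPy as a hypothetical toy example (not from the original post): the four corners of the unit square. At x = 0 and at x = 1 there is one point at y = 0 and one at y = 1, so any line that passes between them at both x values has a total absolute error of exactly 2, the smallest possible.

```python
import numpy as np

# Hypothetical toy dataset: the four corners of the unit square
x = np.array([0, 0, 1, 1], dtype=float)
y = np.array([0, 1, 0, 1], dtype=float)


def sum_abs_residuals(a, b):
    """Method 2: sum of |E_i| for the line f(x) = a*x + b."""
    return float(np.sum(np.abs(y - (a * x + b))))


# Three visibly different lines (slope, intercept)
lines = {"y = 0.5": (0.0, 0.5), "y = x": (1.0, 0.0), "y = 1 - x": (-1.0, 1.0)}
for name, (a, b) in lines.items():
    # All three print 2.0: a three-way tie for "best" line
    print(f"{name:10s} sum of |E_i| = {sum_abs_residuals(a, b)}")
```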

Method 3:

$$\sum_{i=1}^{n} E_i^2$$

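Beautiful! This is called the least squares criterion, and it fixes both of our major issues. Squared terms are never negative, so opposite signs can no longer cancel each other out. And as long as the xᵢ values are not all identical, the sum is a bowl-shaped (strictly convex) function of a and b with exactly one lowest point, which means exactly one best line.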
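A small NumPy sketch of least squares on a hypothetical toy dataset, the four corners of the unit square, where the absolute-value criterion ties between many lines. Squaring separates them, and `np.polyfit` with degree 1 (which minimizes this same sum of squared residuals) should land on slope 0 and intercept 0.5, up to floating-point rounding.

```python
import numpy as np

# Hypothetical toy dataset: the four corners of the unit square
x = np.array([0, 0, 1, 1], dtype=float)
y = np.array([0, 1, 0, 1], dtype=float)


def sum_squared_residuals(a, b):
    """Method 3: least squares criterion, sum of E_i**2."""
    return float(np.sum((y - (a * x + b)) ** 2))


lines = {"y = 0.5": (0.0, 0.5), "y = x": (1.0, 0.0), "y = 1 - x": (-1.0, 1.0)}
for name, (a, b) in lines.items():
    # y = 0.5 scores 1.0, both diagonal lines score 2.0
    print(f"{name:10s} sum of E_i^2 = {sum_squared_residuals(a, b)}")

# np.polyfit(x, y, 1) returns [slope, intercept] of the least squares line
slope, intercept = np.polyfit(x, y, 1)
print(f"least squares line: slope={slope:.3f}, intercept={intercept:.3f}")
```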

Knowing this, we can move on to the next step in our analysis of this algorithm.

How does the machine actually find the equation?

There are two main ways to find the equation of a linear regression model.

Firstly, we have the closed-form solution:
If the dataset isn't massive, the slope and intercept can be calculated directly from a formula, in a single shot.
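For a single feature, the closed-form solution is the standard pair of formulas you get by setting the derivatives of the sum of squared residuals to zero: a = Σ(xᵢ - x̄)(yᵢ - ȳ) / Σ(xᵢ - x̄)² and b = ȳ - a x̄. A minimal NumPy sketch on the pizza data from the code demo at the end of this post:

```python
import numpy as np

# Pizza data from the code demo: size in inches, happiness out of 10
x = np.array([6, 8, 10, 12, 14], dtype=float)
y = np.array([3, 5, 7, 9, 10], dtype=float)

# Closed-form least squares for one feature:
#   a = sum((x - x_mean) * (y - y_mean)) / sum((x - x_mean)**2)
#   b = y_mean - a * x_mean
x_mean, y_mean = x.mean(), y.mean()
a = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
b = y_mean - a * x_mean

print(f"slope a = {a:.2f}, intercept b = {b:.2f}")  # slope a = 0.90, intercept b = -2.20
```

Here the sums are 36 and 40, so a = 36 / 40 = 0.9 and b = 6.8 - 0.9 × 10 = -2.2.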

And secondly, gradient descent:
If the dataset is huge (lots of rows or lots of features), computing the closed-form solution becomes slow and memory-hungry. Instead of all that, we let the computer walk downhill step by step: it looks at the slope of the error surface and keeps adjusting the coefficients until it reaches the lowest point.
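A bare-bones gradient descent sketch for the same one-feature problem, on the pizza data from the code demo at the end of this post. It minimizes the mean squared error (the least squares sum divided by n, which has the same minimum). The learning rate and number of steps are hand-picked assumptions: because x is not rescaled, a learning rate of about 0.01 or more would make the updates overshoot and diverge.

```python
import numpy as np

# Pizza data from the code demo: size in inches, happiness out of 10
x = np.array([6, 8, 10, 12, 14], dtype=float)
y = np.array([3, 5, 7, 9, 10], dtype=float)
n = len(x)

a, b = 0.0, 0.0          # start from a flat line through the origin
learning_rate = 0.005    # hypothetical choice; must stay small because x is unscaled
n_steps = 50_000         # hypothetical choice; enough steps to converge here

for _ in range(n_steps):
    residuals = y - (a * x + b)               # E_i for the current line
    # Gradient of the mean squared error (1/n) * sum(E_i**2) w.r.t. a and b
    grad_a = -2.0 / n * np.sum(x * residuals)
    grad_b = -2.0 / n * np.sum(residuals)
    # Step "downhill": move against the gradient
    a -= learning_rate * grad_a
    b -= learning_rate * grad_b

print(f"slope a = {a:.2f}, intercept b = {b:.2f}")
```

It should settle on the same slope (0.90) and intercept (-2.20) as the closed-form formula. In practice, features are usually standardized first so that a larger learning rate and far fewer steps suffice.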

As much as I would love to explain the math further, it could get boring and would go beyond the scope of this article. If you are interested, I could cover it in another article in the future.

Yeah, but this is a toy, right? Real problems have SO MANY variables to account for!

With one input variable, we're fitting a line on a 2D graph; with two variables, it becomes a plane in 3D; and as the number of variables grows, it becomes a hyperplane that we can no longer visualize.

It is therefore easier to illustrate with an example: housing prices in the hypothetical city of "Machineland", where flying cars handle transport and humanoids run the government.

$$\text{Price} = 100 \cdot \text{Area} + 7000 \cdot \text{Bedrooms} + 30000 \cdot \text{Location} + 25 \cdot \text{(humanoids within a 1 km}^2\text{ range)} + \text{Intercept}$$

Each coefficient tells you how much the price changes when that feature increases by one unit, keeping all the other features constant.
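A NumPy sketch of fitting several features at once. Everything here is hypothetical: the feature ranges are made up, the intercept of 50,000 is an arbitrary choice (the post does not give one), and the prices are generated directly from the Machineland equation with no noise, so least squares (`np.linalg.lstsq`) should recover the coefficients almost exactly. With real, noisy data the estimates would only be approximate.

```python
import numpy as np

rng = np.random.default_rng(seed=0)
n_houses = 200

# Synthetic Machineland features (all ranges are hypothetical)
area = rng.uniform(50, 300, n_houses)        # area units are an assumption
bedrooms = rng.integers(1, 6, n_houses)      # 1 to 5 bedrooms
location = rng.integers(0, 2, n_houses)      # 1 = prime location, 0 = otherwise
humanoids = rng.integers(0, 500, n_houses)   # humanoids within a 1 km^2 range

INTERCEPT = 50_000  # hypothetical; the post leaves the intercept unspecified
price = (100 * area + 7000 * bedrooms + 30000 * location
         + 25 * humanoids + INTERCEPT)

# Design matrix: one column per feature plus a column of ones for the intercept
X = np.column_stack([area, bedrooms, location, humanoids, np.ones(n_houses)])
coefs, *_ = np.linalg.lstsq(X, price, rcond=None)

for name, c in zip(["area", "bedrooms", "location", "humanoids", "intercept"], coefs):
    print(f"{name:10s} {c:12.2f}")
```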

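One limitation, however, is that the input variables can be strongly correlated with each other (bigger houses tend to have more bedrooms, for example). Welcome to multicollinearity: when features move together, the model cannot cleanly separate their individual effects, so the coefficients become unstable and hard to interpret.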
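The most extreme form of multicollinearity is a feature that is an exact multiple of another. In this hypothetical NumPy sketch, house area is recorded twice, once in square metres and once in square feet: the feature matrix loses rank, and two completely different coefficient vectors give identical predictions, so the data cannot tell them apart.

```python
import numpy as np

# Hypothetical: the same area recorded twice, in square metres and in square feet
area_m2 = np.array([50.0, 80.0, 120.0, 200.0])
area_ft2 = area_m2 * 10.7639          # perfectly collinear with area_m2
X = np.column_stack([area_m2, area_ft2])

# Only one column carries independent information
print("rank:", np.linalg.matrix_rank(X), "of", X.shape[1], "columns")

# Two very different coefficient vectors...
w1 = np.array([100.0, 0.0])             # all weight on square metres
w2 = np.array([0.0, 100.0 / 10.7639])   # all weight on square feet
# ...produce the same predictions
print(np.allclose(X @ w1, X @ w2))
```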

Why should I care?

  • Interpretable (easy to explain results -> +25 to the price per extra humanoid in the 1 km^2 range)
  • Fast (cheap to compute, even on large datasets)
  • Baseline (a natural first model to beat in almost any regression pipeline)

Code demo

Run this script in your Python IDE (it needs scikit-learn and NumPy installed) so that we can relate pizza size to happiness using linear regression 🍕🥳

```python
from sklearn.linear_model import LinearRegression
import numpy as np

# Pizza size in inches; scikit-learn expects a 2D array of shape (n_samples, n_features)
X = np.array([[6], [8], [10], [12], [14]])
# Happiness rating out of 10 (totally made up!)
y = np.array([3, 5, 7, 9, 10])

# Train the model (fits the slope and intercept by least squares)
model = LinearRegression()
model.fit(X, y)

# Round to 2 decimals so floating-point noise doesn't clutter the output
print(f"Slope: {model.coef_[0]:.2f}")
print(f"Intercept: {model.intercept_:.2f}")
print(f"Predicted happiness for a 16-inch pizza: {model.predict([[16]])[0]:.2f}")
```

Your output should look like:

```
Slope: 0.90
Intercept: -2.20
Predicted happiness for a 16-inch pizza: 12.20
```

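Linear regression basically learns that the bigger the pizza, the happier the human! Notice, though, that the 16-inch prediction of 12.2 overshoots the 10-point scale: a straight line keeps going past the range of data it was trained on, so be careful when extrapolating.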

Conclusion

I hope that you enjoyed my overview of this interesting topic. So before your next "wrapping an LLM and calling it an AI project", please pay some respect to the cradle of machine learning: linear regression.

Thank you for reading!
