DEV Community

Awaliyatul Hikmah

Breaking Down Cost Functions in Linear Regression: A Conceptual Overview

The cost function is a crucial concept in machine learning, helping us understand how well our models are performing. It's the tool that tells us how close our model's predictions are to the actual results and guides us in improving accuracy. In this post, we'll break down the cost function in simple terms, with a focus on linear regression.

Introduction to Cost Functions

The cost function tells us how well our model's predictions match the actual target values. Essentially, it measures the error between the predicted values and the true values. By minimizing this error, we can improve our model's accuracy.

Suppose you have a training set with input features $x$ and output targets $y$.

| Size in feet² (x) | Price in $1000s (y) |
|---|---|
| 2104 | 460 |
| 1416 | 232 |
| 1534 | 315 |
| 852 | 178 |

The model you use to fit this training set can be represented by a linear function:

$$f_{w,b}(x) = w \cdot x + b$$

For a training example $(x^{(i)}, y^{(i)})$, the function $f$ produces a prediction $\hat{y}^{(i)}$ of the target $y^{(i)}$. Thus:

$$\hat{y}^{(i)} = f_{w,b}(x^{(i)}) = w \cdot x^{(i)} + b$$
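A minimal Python sketch of the model, assuming the training set from the table is stored in plain lists. The name `predict` and the parameter values `w = 0.2`, `b = 20.0` are hypothetical, picked by hand for illustration; they are not fitted values from the original post.

```python
# Training set from the table: size in square feet (x), price in $1000s (y)
x_train = [2104.0, 1416.0, 1534.0, 852.0]
y_train = [460.0, 232.0, 315.0, 178.0]

def predict(x, w, b):
    """Linear model f_{w,b}(x) = w * x + b, applied to each example."""
    return [w * xi + b for xi in x]

# Hypothetical parameters chosen by hand for illustration (not fitted)
w, b = 0.2, 20.0
y_hat = predict(x_train, w, b)

for xi, yi, yhi in zip(x_train, y_train, y_hat):
    print(f"x = {xi:6.0f}  target y = {yi:5.1f}  prediction y_hat = {yhi:6.1f}")
```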

The challenge is to find $w$ and $b$ that make the prediction $\hat{y}^{(i)}$ close to the target $y^{(i)}$ for all training examples.

Here, $w$ and $b$ are the parameters of the model. These parameters are adjusted during training to enhance the model's performance.

[Figure: different parameters, different line]

Depending on the values chosen for $w$ and $b$, we get different functions $f(x)$, which generate different lines on a graph. Writing $f(x)$ as shorthand for $f_{w,b}(x)$, we can look at some plots to understand how $w$ and $b$ influence $f$.

  • When $w = 0$ and $b = 1.5$:
$$f(x) = 0 \cdot x + 1.5$$

The function is a horizontal line, predicting a constant value of 1.5.

[Figure: plot of f(x) with w = 0, b = 1.5]

  • When $w = 0.5$ and $b = 0$:
$$f(x) = 0.5 \cdot x$$

The slope is 0.5, so the output rises by 0.5 for every one-unit increase in $x$, and because $b = 0$ the line passes through the origin.

[Figure: plot of f(x) with w = 0.5, b = 0]

  • When $w = 0.5$ and $b = 1$:
$$f(x) = 0.5 \cdot x + 1$$

This line has a slope of 0.5 and crosses the vertical axis at 1, the value of $b$.

[Figure: plot of f(x) with w = 0.5, b = 1]
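To see the effect of $w$ and $b$ numerically rather than graphically, this short Python sketch evaluates the three lines above at a few sample inputs. The inputs 0, 1 and 2 are arbitrary choices for illustration. Changing $b$ shifts every output by the same amount, while $w$ controls how fast the output grows with $x$.

```python
def f(x, w, b):
    """f(x) = w * x + b for a single input x."""
    return w * x + b

# The three (w, b) settings discussed above
settings = [(0.0, 1.5), (0.5, 0.0), (0.5, 1.0)]
xs = [0.0, 1.0, 2.0]

for w, b in settings:
    values = [f(x, w, b) for x in xs]
    print(f"w = {w}, b = {b}: f(0), f(1), f(2) = {values}")
# w = 0.0, b = 1.5: f(0), f(1), f(2) = [1.5, 1.5, 1.5]
# w = 0.5, b = 0.0: f(0), f(1), f(2) = [0.0, 0.5, 1.0]
# w = 0.5, b = 1.0: f(0), f(1), f(2) = [1.0, 1.5, 2.0]
```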

Understanding Errors

The cost function calculates the error between the predicted prices $\hat{y}$ and the actual prices $y$. This error is given by:

$$\text{Error} = \hat{y} - y$$

Next, we square this error so that it can never be negative. Squaring makes every error non-negative and emphasizes larger errors more than smaller ones:

$$\text{Squared error} = (\hat{y} - y)^2$$

If we have more training examples, the sum of the squared errors naturally grows, even if the model fits the data no better or worse. To normalize for this, we use the average squared error instead of the total squared error to get a sense of the overall performance. This way, the cost function doesn't get bigger just because we have more training examples, which keeps the comparison fair no matter how large our dataset is. Dividing by the number of examples $m$ gives:

$$\text{Mean Squared Error} = \frac{1}{m} \sum_{i=1}^{m} (\hat{y}^{(i)} - y^{(i)})^2$$
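To make the normalization argument concrete, the Python sketch below compares the total and the average squared error before and after repeating every training example once. The predictions come from hand-picked parameters ($w = 0.2$, $b = 20$), an illustrative assumption rather than fitted values; duplicating the data is simply a way to get a larger dataset with exactly the same fit quality.

```python
def squared_errors(y_hat, y):
    """Squared error (y_hat - y)^2 for each example."""
    return [(yh - yi) ** 2 for yh, yi in zip(y_hat, y)]

def sum_squared_error(y_hat, y):
    return sum(squared_errors(y_hat, y))

def mean_squared_error(y_hat, y):
    return sum_squared_error(y_hat, y) / len(y)

# Targets from the table; predictions from hypothetical parameters w = 0.2, b = 20
y = [460.0, 232.0, 315.0, 178.0]
y_hat = [0.2 * x + 20.0 for x in [2104.0, 1416.0, 1534.0, 852.0]]

# A "bigger" dataset: every example repeated once
y2, y_hat2 = y * 2, y_hat * 2

# The total doubles, while the average stays the same (up to floating-point rounding)
print(sum_squared_error(y_hat, y), sum_squared_error(y_hat2, y2))
print(mean_squared_error(y_hat, y), mean_squared_error(y_hat2, y2))
```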

To simplify the derivative calculations during optimization (for example, in gradient descent), we add a factor of $\frac{1}{2}$ to the cost function. The final cost function is therefore:

$$J(w, b) = \frac{1}{2m} \sum_{i=1}^{m} (f_{w,b}(x^{(i)}) - y^{(i)})^2$$

The extra division by 2 is a bit of a mathematical trick to make later calculations easier, especially when we use calculus to minimize the cost function.
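The formula translates directly into code. Below is a minimal Python sketch, with `compute_cost` as a hypothetical helper name: it loops over the examples, accumulates the squared errors, and divides by $2m$. The second parameter pair is a hand-picked, unfitted choice used only to show that a lower cost means a better fit.

```python
def compute_cost(x, y, w, b):
    """Squared error cost J(w, b) = (1 / 2m) * sum((w * x_i + b - y_i)^2)."""
    m = len(x)
    total = 0.0
    for xi, yi in zip(x, y):
        f_wb = w * xi + b          # prediction for example i
        total += (f_wb - yi) ** 2  # squared error for example i
    return total / (2 * m)

# Training set from the table above
x_train = [2104.0, 1416.0, 1534.0, 852.0]
y_train = [460.0, 232.0, 315.0, 178.0]

print(compute_cost(x_train, y_train, w=0.0, b=0.0))   # 49541.625
print(compute_cost(x_train, y_train, w=0.2, b=20.0))  # much lower: a better fit
```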

Why Do We Square the Error?

Imagine we're trying to predict something, like the price of a house. Our model makes a prediction $\hat{y}$, and we compare it to the actual price $y$. The error is the difference between these two:

$$\text{Error} = \hat{y} - y$$

But here's the thing: this error can be positive or negative.

  • If your prediction is higher than the actual value, the error is positive.
  • If your prediction is lower than the actual value, the error is negative.

Example:

  • Predicted price ($\hat{y}$): $300,000
  • Actual price ($y$): $280,000

The error is 300,000 − 280,000 = 20,000 (positive error)

But:

  • Predicted price ($\hat{y}$): $250,000
  • Actual price ($y$): $280,000

The error is 250,000 − 280,000 = −30,000 (negative error)

If we simply add up these errors, positive and negative values can cancel each other out, which wouldn't give us the real picture of how well our model is performing. Squaring each error before adding them up removes the sign, so the errors can no longer cancel.

This helps us understand the real performance of our model and work on making better predictions.
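Plugging the two house-price examples above into a quick Python check shows the problem: the raw errors partly cancel, so their sum (−10,000) is smaller in size than either individual mistake, while the sum of squared errors keeps both mistakes visible.

```python
predicted = [300_000, 250_000]
actual = [280_000, 280_000]

errors = [p - a for p, a in zip(predicted, actual)]  # [20000, -30000]

raw_sum = sum(errors)                      # positive and negative terms partly cancel
squared_sum = sum(e ** 2 for e in errors)  # every term is non-negative

print(raw_sum)      # -10000
print(squared_sum)  # 1300000000
```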

Why Not Use Absolute Value to Avoid Negative Value?

When we square the errors, larger errors have a bigger impact. For example, an error of 10 becomes 100 when squared, while an error of 1 stays 1. This pushes the model to reduce larger mistakes more aggressively. The squared error function is also smooth and differentiable everywhere. This smoothness is important for optimization algorithms like gradient descent: the gradient changes continuously and shrinks as the error approaches zero, which allows for more efficient and predictable convergence to the minimum error.

With absolute error, by contrast, each error contributes linearly: an error of 10 remains 10, and an error of 1 remains 1. Each error counts in proportion to its size, with no extra emphasis on the larger ones. The absolute value function also has a kink at zero, meaning it's not differentiable at that point. This can complicate the optimization process, making it harder to find the minimum error.

In summary, squaring the error puts more emphasis on larger mistakes, which helps in creating a better model overall by addressing those big errors more effectively. This is why squared errors are often preferred in many machine learning applications.
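Both properties can be checked in a few lines of Python. This sketch compares how the two error measures treat errors of 1 and 10, then approximates the slope of each function just to the left and right of zero with one-sided difference quotients; the step `h = 1e-6` is an arbitrary small value.

```python
def squared(e):
    return e * e

def absolute(e):
    return abs(e)

errors = [1, 10]
print([squared(e) for e in errors])   # [1, 100]: the larger error dominates
print([absolute(e) for e in errors])  # [1, 10]: each error counts in proportion to its size

# One-sided difference quotients at e = 0 approximate the slope from each side
h = 1e-6
sq_right = (squared(h) - squared(0.0)) / h
sq_left = (squared(0.0) - squared(-h)) / h
abs_right = (absolute(h) - absolute(0.0)) / h
abs_left = (absolute(0.0) - absolute(-h)) / h

print(sq_left, sq_right)    # both close to 0: the slope is continuous at 0
print(abs_left, abs_right)  # -1.0 1.0: the slope jumps, so |e| has no derivative at 0
```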

Why Divide by $2m$ and Not $100m$?

You may wonder: if our goal is to keep the cost from growing just because the dataset gets bigger, why divide by $2m$? Why 2? Why not 100 or some other number? As I mentioned before, the extra division by 2 is a bit of a mathematical trick to make later calculations easier. Specifically, the $\frac{1}{2}$ cancels the 2 that appears when we differentiate the squared term, simplifying our calculations.

When training the model, we often use optimization algorithms like gradient descent to minimize the cost function. Gradient descent involves taking the derivative (gradient) of the cost function with respect to the parameters $w$ and $b$. The gradient tells us how to change the parameters to reduce the cost.

Consider a simple function:

$$J(w) = \frac{1}{m} \sum_{i=1}^{m} (wx^{(i)} - y^{(i)})^2$$

When we take the derivative of this function with respect to $w$:

$$\frac{dJ}{dw} = \frac{1}{m} \sum_{i=1}^{m} 2(wx^{(i)} - y^{(i)})x^{(i)}$$

The derivative produces a 2 from the squared term. This 2 can make the gradient calculations a bit cumbersome. By including a factor of 1/2 in the cost function, we simplify the gradient calculations. This adjustment doesn't change the ultimate goal (minimizing the cost), but it makes the math cleaner:

$$J(w) = \frac{1}{2m} \sum_{i=1}^{m} (wx^{(i)} - y^{(i)})^2$$

Now, when we take the derivative, the factor of 2 cancels out:

$$\frac{dJ}{dw} = \frac{1}{m} \sum_{i=1}^{m} (wx^{(i)} - y^{(i)})x^{(i)}$$
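A quick numerical check of this derivation, as a Python sketch: it compares the analytic gradient $\frac{1}{m}\sum_{i=1}^{m} (wx^{(i)} - y^{(i)})x^{(i)}$ with a central finite-difference estimate, and confirms that dropping the $\frac{1}{2}$ doubles the gradient. The data is the table above; the test point `w = 0.1` and the step `h = 1e-3` are arbitrary choices. Because $J(w)$ is quadratic, the central difference matches the analytic derivative up to floating-point rounding.

```python
# Simplified model f_w(x) = w * x, using the table's data (square feet, $1000s)
x = [2104.0, 1416.0, 1534.0, 852.0]
y = [460.0, 232.0, 315.0, 178.0]
m = len(x)

def cost_1_over_m(w):
    """J(w) = (1/m) * sum((w*x_i - y_i)^2), without the 1/2 factor."""
    return sum((w * xi - yi) ** 2 for xi, yi in zip(x, y)) / m

def cost_1_over_2m(w):
    """J(w) = (1/2m) * sum((w*x_i - y_i)^2), the usual convention."""
    return sum((w * xi - yi) ** 2 for xi, yi in zip(x, y)) / (2 * m)

def grad_1_over_2m(w):
    """Analytic derivative of cost_1_over_2m: (1/m) * sum((w*x_i - y_i) * x_i)."""
    return sum((w * xi - yi) * xi for xi, yi in zip(x, y)) / m

def numerical_grad(cost, w, h=1e-3):
    """Central finite-difference estimate of dJ/dw."""
    return (cost(w + h) - cost(w - h)) / (2 * h)

w = 0.1
print(grad_1_over_2m(w), numerical_grad(cost_1_over_2m, w))  # should agree
print(numerical_grad(cost_1_over_m, w) / numerical_grad(cost_1_over_2m, w))  # about 2
```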

Why Not Another Number?

If we used a different number, like 100, in the cost function, the math wouldn't simplify as neatly:

$$J(w) = \frac{1}{100m} \sum_{i=1}^{m} (wx^{(i)} - y^{(i)})^2$$

Taking the derivative:

$$\frac{dJ}{dw} = \frac{1}{50m} \sum_{i=1}^{m} (wx^{(i)} - y^{(i)})x^{(i)}$$

Here, the factor $\frac{1}{50}$ doesn't simplify as nicely, and we end up with messier expressions. The gradient is also 50 times smaller than with $\frac{1}{2m}$, so with the same learning rate $\alpha$ each gradient descent step is 50 times smaller and the model learns very slowly. You would have to compensate by increasing $\alpha$, but this requires careful tuning to avoid making the model unstable.

Using $\frac{1}{2m}$ is a balanced choice. It simplifies the gradient calculations without making the steps too small or too large. It's also a widely accepted convention, which makes it easier to follow standard practices and compare results across different studies and implementations.
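The slowdown is easy to reproduce. This Python sketch runs 20 steps of gradient descent on the simplified model $f_w(x) = wx$ with the same learning rate for both scalings. The toy data (three points on $y = x$, so the best $w$ is 1), the starting point $w = 0$, and `alpha = 0.1` are all illustrative assumptions, not values from the original post.

```python
# Assumed toy data lying exactly on y = x (not from the original post), so the best w is 1
x = [1.0, 2.0, 3.0]
y = [1.0, 2.0, 3.0]
m = len(x)

def grad(w, scale):
    """dJ/dw for J(w) = scale * sum((w*x_i - y_i)^2); scale is 1/(2m) or 1/(100m)."""
    return 2 * scale * sum((w * xi - yi) * xi for xi, yi in zip(x, y))

def gradient_descent(scale, alpha, steps=20):
    w = 0.0
    for _ in range(steps):
        w -= alpha * grad(w, scale)
    return w

w_2m = gradient_descent(scale=1 / (2 * m), alpha=0.1)
w_100m = gradient_descent(scale=1 / (100 * m), alpha=0.1)
w_100m_fast = gradient_descent(scale=1 / (100 * m), alpha=0.1 * 50)

print(w_2m)         # close to 1 after 20 steps
print(w_100m)       # still far from 1: each step is 50 times smaller
print(w_100m_fast)  # matches w_2m once alpha is scaled up by 50
```

Scaling the cost by a constant never moves the minimum; it only rescales the gradient, which is why a 50 times larger $\alpha$ exactly compensates here.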

Visualizing the Cost Function

In linear regression, the objective is to find the optimal values for the parameters $w$ and $b$ that minimize the cost function $J(w,b)$. This is typically achieved through an optimization algorithm, such as gradient descent, which iteratively adjusts $w$ and $b$ to reduce the difference between the predicted outputs and the actual target values.
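A minimal Python sketch of batch gradient descent for $J(w, b)$ on the table's data. It uses the partial derivatives $\frac{\partial J}{\partial w} = \frac{1}{m}\sum_{i=1}^{m} (f_{w,b}(x^{(i)}) - y^{(i)})x^{(i)}$ and $\frac{\partial J}{\partial b} = \frac{1}{m}\sum_{i=1}^{m} (f_{w,b}(x^{(i)}) - y^{(i)})$. Two assumptions are not in the original post: sizes are rescaled to thousands of square feet so that one learning rate works for both parameters, and `alpha = 0.5` and the iteration count are hand-picked.

```python
# Training set from the table, with size rescaled to thousands of square feet (assumption)
x = [2.104, 1.416, 1.534, 0.852]
y = [460.0, 232.0, 315.0, 178.0]

def compute_gradients(x, y, w, b):
    """Partial derivatives of J(w, b) = (1/2m) * sum((w*x_i + b - y_i)^2)."""
    m = len(x)
    errors = [w * xi + b - yi for xi, yi in zip(x, y)]
    dj_dw = sum(e * xi for e, xi in zip(errors, x)) / m
    dj_db = sum(errors) / m
    return dj_dw, dj_db

def gradient_descent(x, y, alpha=0.5, iterations=10_000):
    w, b = 0.0, 0.0
    for _ in range(iterations):
        dj_dw, dj_db = compute_gradients(x, y, w, b)
        # Update w and b simultaneously, using gradients from the same (w, b)
        w, b = w - alpha * dj_dw, b - alpha * dj_db
    return w, b

w, b = gradient_descent(x, y)
print(f"w = {w:.2f}, b = {b:.2f}")
```

At the minimum both partial derivatives are zero, so checking that `compute_gradients(x, y, w, b)` returns values close to zero is a simple way to confirm convergence.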

To build intuition for how the cost depends on the parameters, let's work with a simplified version of the linear regression model:

$$f_w(x) = w \cdot x$$

In this model, we've eliminated the parameter $b$. Now, the cost function looks like this:

$$J(w) = \frac{1}{2m} \sum_{i=1}^{m} (wx^{(i)} - y^{(i)})^2$$

The goal is to find the value of $w$ that minimizes $J(w)$. Let's visualize how the cost function changes with different values of $w$.

The graphs below show both the function $f_w(x)$ (left) and the corresponding cost function $J(w)$ (right) for four different values of $w$:

  1. When $w = 1$: The function $f_w(x)$ is a line with a slope of 1 that passes through every data point, so the cost $J(w)$ is 0, the lowest possible value.

    • Function: $f(x) = 1 \cdot x$
    • Graph: [Figure: graph for w = 1]
  2. When $w = 0.5$: The function $f_w(x)$ has a slope of 0.5, so the line does not fit the data well. The gap between predicted and actual values leads to a higher cost.

    • Function: $f(x) = 0.5 \cdot x$
    • Graph: [Figure: graph for w = 0.5]
  3. When $w = 0$: The function $f_w(x)$ is a flat, horizontal line that does not fit the data points at all, resulting in a significant error and an even higher cost.

    • Function: $f(x) = 0 \cdot x$
    • Graph: [Figure: graph for w = 0]
  4. When $w = -0.5$: The function $f_w(x)$ is a line that slopes downward, running opposite to the trend of the data points and resulting in the highest cost of the four.

    • Function: $f(x) = -0.5 \cdot x$
    • Graph: [Figure: graph for w = -0.5]


Here's how the plots would look:

[Figure: cost function J(w) for all values of w]

These graphs help visualize how different values of $w$ affect the line that fits the data points and the corresponding value of the cost function $J(w)$. The goal is to find the value of $w$ that results in the lowest cost, indicating the best fit for the data.
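The trend across the four cases can also be reproduced numerically. The original post does not list the data points behind the plots, so this Python sketch assumes three hypothetical points on the line $y = x$, which is consistent with $w = 1$ giving a cost of 0. The printed costs rise in the same order as described above.

```python
# Assumed toy data (not given in the original post): three points on y = x,
# so the line with w = 1 passes through all of them
x = [1.0, 2.0, 3.0]
y = [1.0, 2.0, 3.0]

def cost(x, y, w):
    """J(w) = (1/2m) * sum((w*x_i - y_i)^2) for the model f_w(x) = w * x."""
    m = len(x)
    return sum((w * xi - yi) ** 2 for xi, yi in zip(x, y)) / (2 * m)

for w in [1.0, 0.5, 0.0, -0.5]:
    print(f"w = {w:5.1f}  J(w) = {cost(x, y, w):.3f}")
# w =   1.0  J(w) = 0.000
# w =   0.5  J(w) = 0.583
# w =   0.0  J(w) = 2.333
# w =  -0.5  J(w) = 5.250
```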

Choosing the Optimal Parameters

The goal is to choose the value of $w$ that minimizes $J(w)$, that is, the value of $w$ that gives the smallest possible cost. For instance, in our example, $w = 1$ results in the smallest $J(w)$, so $w = 1$ is the optimal parameter for our model.
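For a single parameter, this selection can be sketched as a brute-force search: evaluate $J(w)$ over a grid of candidate values and keep the one with the smallest cost. This Python sketch assumes toy data on the line $y = x$ and a hand-picked grid from −0.5 to 1.5. It is only illustrative; grid search scales poorly as the number of parameters grows, which is one reason iterative methods such as gradient descent are preferred.

```python
# Assumed toy data (not from the original post): points lying on y = x
x = [1.0, 2.0, 3.0]
y = [1.0, 2.0, 3.0]

def cost(w):
    """J(w) = (1/2m) * sum((w*x_i - y_i)^2) for f_w(x) = w * x."""
    m = len(x)
    return sum((w * xi - yi) ** 2 for xi, yi in zip(x, y)) / (2 * m)

# Candidate slopes from -0.5 to 1.5 in steps of 0.1
candidates = [i / 10 for i in range(-5, 16)]
best_w = min(candidates, key=cost)
print(best_w, cost(best_w))  # 1.0 0.0
```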

Different Cost Functions for Different Applications

While the mean squared error cost function is the most commonly used for linear regression, different applications may require different cost functions. The mean squared error is popular because it generally provides good results for many regression problems.

Conclusion

The cost function is a fundamental concept in machine learning that helps us measure how well our model's predictions align with the actual target values. It guides the optimization process to improve the model's accuracy. Understanding and utilizing the cost function effectively can significantly enhance the performance of your machine learning models.
