Linear regression explained


This article is the first in a long series of upcoming articles on machine learning models. Subscribe to my newsletter if you want to be notified when they are published and keep improving your ML knowledge and skills.

But for now, let's focus on linear regression.


If you're interested in the implementation of linear regression, click here.

What is linear regression?

Let's say you measured a couple of points and plotted them on a graph (x, y). These values can represent anything you want: the price of a house as a function of its living area, the weight of a tree as a function of its height, and so on.

For my example, these are the values that I'm going to use:

(1.5, 2)
(2, 0.5)
(2.7, 3.3)
(3.6, 4.5)
(3.7, 4.7)
(4.5, 7.5)
(6.1, 6.5)
(6.4, 8.8)
(7.5, 9.6)
(8.5, 9.7)

When plotted, this is what they look like:

Now, let's ask ourselves: If we added a point, knowing its x-value, could we predict its y-value? For example, what could be an estimation of the y-value for x=5?

To answer this, we need to find a relation between y and x. If we assume that relation is linear, we can write it (as you might remember from school) as the equation of a line: y = ax + b

In linear regression, we write: y = b_0 + b_1*x

The whole idea behind linear regression is to find those b_0 and b_1 values.
Once we have them, we can take x=5 and see what y-value the line predicts.

For the curious ones out there, this is what a good fit (although not the best) might look like:

Using this model (y = 1.25x, that is, a slope of 1.25 and an intercept of 0), plugging in x=5 gives y = 1.25*5 = 6.25.
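For comparison, here is a minimal sketch of the line that minimizes the sum of squared errors on the same ten points. It uses the closed-form least-squares formulas rather than the gradient descent approach described later in this article, and it assumes squared error as the definition of "best". The helper name `least_squares_fit` is hypothetical.

```python
# Data points from the example above
xs = [1.5, 2, 2.7, 3.6, 3.7, 4.5, 6.1, 6.4, 7.5, 8.5]
ys = [2, 0.5, 3.3, 4.5, 4.7, 7.5, 6.5, 8.8, 9.6, 9.7]


def least_squares_fit(xs, ys):
    """Return (a, b) for the line y = a*x + b minimizing the sum of squared errors."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Slope: sum of co-deviations of x and y divided by sum of squared deviations of x
    s_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    s_xx = sum((x - mean_x) ** 2 for x in xs)
    a = s_xy / s_xx
    # The least-squares line always passes through (mean_x, mean_y)
    b = mean_y - a * mean_x
    return a, b


a, b = least_squares_fit(xs, ys)
print(f"a = {a:.4f}, b = {b:.4f}")            # a = 1.2804, b = -0.2440
print(f"prediction at x=5: {a * 5 + b:.2f}")  # prediction at x=5: 6.16
```

Under this assumption the best line predicts roughly 6.16 at x=5, close to the 6.25 from the hand-picked line.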

How does linear regression work?

To make linear regression work, we have to establish a relation between the points and the estimated line. This relation is the "cost" (the error). There are several cost functions to choose from; we pick one and try to minimize it. The smaller the cost, the better the line fits. We'd like it to approach zero, but with real data that's almost never possible (unless the line fits the points perfectly).

Here is an example of how different lines will yield different costs:

Notice that one of the lines fits better than all the others, and it yields the lowest cost.
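To make "different lines yield different costs" concrete, here is a small sketch that scores a few candidate lines on the example data. It assumes the mean squared error as the cost; the function name `mse_cost` and the candidate (a, b) pairs are hypothetical choices for illustration.

```python
# Data points from the example above
xs = [1.5, 2, 2.7, 3.6, 3.7, 4.5, 6.1, 6.4, 7.5, 8.5]
ys = [2, 0.5, 3.3, 4.5, 4.7, 7.5, 6.5, 8.8, 9.6, 9.7]


def mse_cost(a, b, xs, ys):
    """Mean squared error of the line y = a*x + b over the points (xs, ys)."""
    n = len(xs)
    return sum((y - (a * x + b)) ** 2 for x, y in zip(xs, ys)) / n


# Arbitrary candidate lines (a, b): a flat line at the mean of y,
# a line that is too steep, and the hand-picked y = 1.25x
candidates = [(0.0, 5.71), (2.0, -2.0), (1.25, 0.0)]
for a, b in candidates:
    print(f"a = {a}, b = {b}: cost = {mse_cost(a, b, xs, ys):.4f}")
# a = 0.0, b = 5.71: cost = 9.3629
# a = 2.0, b = -2.0: cost = 6.1830
# a = 1.25, b = 0.0: cost = 1.0354
```

The line that follows the trend of the points gets by far the lowest cost.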

To find the best fit, we simply have to find the best b_0 and b_1.

For the sake of simplicity, we'll use the y = ax + b notation, so b_0 becomes b and b_1 becomes a.

The idea is to define a cost function, cost(a, b), and minimize it. We'll use the gradient descent method (which I'll soon write an article about) to adjust the values of a and b so that cost(a, b) decreases. Gradient descent needs the partial derivatives of the cost with respect to a and b, so I'd recommend starting with a cost function that isn't too complicated to differentiate.

This might be slightly confusing, so let's work through an example:

For the cost function we'll choose:
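A common choice, assumed in the sketches in this section, is the mean squared error (MSE) of the line over the n points (x_i, y_i):

$$
\mathrm{cost}(a, b) = \frac{1}{n} \sum_{i=1}^{n} \left( y_i - (a x_i + b) \right)^2
$$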

Then we'll have to see in what way a and b affect the cost function. To do this we'll calculate both partial derivatives:
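Assuming the mean squared error cost(a, b) = (1/n) Σ (y_i - (a x_i + b))², the chain rule gives the two partial derivatives below (the derivative of the inner term with respect to a is -x_i, and with respect to b is -1):

$$
\frac{\partial\,\mathrm{cost}}{\partial a} = -\frac{2}{n} \sum_{i=1}^{n} x_i \left( y_i - (a x_i + b) \right),
\qquad
\frac{\partial\,\mathrm{cost}}{\partial b} = -\frac{2}{n} \sum_{i=1}^{n} \left( y_i - (a x_i + b) \right)
$$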

The result we got is the gradient, which points in the direction that increases cost(a, b). So we'll move in the opposite direction (its negative), scaled by a learning rate, to update the values of a and b:
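Written out, one gradient descent step updates both parameters at the same time, where α (a symbol introduced here) is the learning rate, a small positive number:

$$
a \leftarrow a - \alpha \, \frac{\partial\,\mathrm{cost}}{\partial a},
\qquad
b \leftarrow b - \alpha \, \frac{\partial\,\mathrm{cost}}{\partial b}
$$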

To minimize the cost, we repeat this update step over a certain number of epochs (iterations). By the end, a and b should be optimized so that the line fits the points nicely.
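Here is a minimal sketch of the whole loop on the example data, assuming the mean squared error as the cost. The function name `gradient_descent`, the starting point (a = b = 0), the learning rate of 0.01 and the 10,000 epochs are illustrative assumptions, not tuned or prescribed values.

```python
# Data points from the example above
xs = [1.5, 2, 2.7, 3.6, 3.7, 4.5, 6.1, 6.4, 7.5, 8.5]
ys = [2, 0.5, 3.3, 4.5, 4.7, 7.5, 6.5, 8.8, 9.6, 9.7]


def gradient_descent(xs, ys, learning_rate=0.01, epochs=10_000):
    """Fit y = a*x + b by batch gradient descent on the mean squared error."""
    n = len(xs)
    a, b = 0.0, 0.0  # arbitrary starting line
    for _ in range(epochs):
        # Residuals y_i - (a*x_i + b) for the current line
        residuals = [y - (a * x + b) for x, y in zip(xs, ys)]
        # Partial derivatives of the MSE with respect to a and b
        grad_a = -2.0 / n * sum(x * r for x, r in zip(xs, residuals))
        grad_b = -2.0 / n * sum(residuals)
        # Step against the gradient, scaled by the learning rate
        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
    return a, b


a, b = gradient_descent(xs, ys)
print(f"a = {a:.4f}, b = {b:.4f}")  # a = 1.2804, b = -0.2440
```

With these settings the fitted line predicts about 6.16 for x=5. If the learning rate is too large (for this data, roughly above 0.036), the updates overshoot and the values of a and b diverge instead of converging.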

To see a real example, click here.


Thanks a lot for reading until the end! I hope you learned something from this, and be sure to subscribe to my newsletter to get notified when the next articles on the other machine learning models are published.

I'll also write an article about multiple linear regression at some point in the future.