In a previous post, I introduced the theory behind the method of least squares and showed how it can be used to solve systems of equations that have no exact solution.
Now, I want to look at one of its most practical applications: least squares fitting. In this tutorial, we'll perform straight-line fitting and polynomial least squares fitting, both by hand and with Python.
Before we look at some example problems, we need a little background and theory.
In least squares fitting, we have some function $f$ that takes $n$-vectors as its inputs and maps them to real numbers. We don't really know anything about the function itself or what it does under the hood. It's your classic black box: you feed some vector $x$ to the function, and it spits out a number $y = f(x)$ in response.
Our goal in least squares fitting is to model $f$ as closely as possible, based on the input-output data pairs that we're given. Typically, we use the following notation for our data, with $(x^{(i)}, y^{(i)})$ denoting the $i$-th data pair:

$$(x^{(1)}, y^{(1)}), \; (x^{(2)}, y^{(2)}), \; \dots, \; (x^{(N)}, y^{(N)})$$

Here, $N$ is the number of data points (i.e., the size of our data set), while $n$ is the size of each input vector, $x^{(i)}$. Keep that in mind, because these two numbers are not necessarily the same.
Note: We use superscripts in parentheses to index data pairs because subscripts are usually reserved for the elements of a vector. So the first element of the second input vector would be written as $x^{(2)}_1$.
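If it helps to see this notation in code, here's a minimal sketch that stores a made-up data set in NumPy arrays (the numbers are hypothetical). One assumption worth flagging: Python indexes from 0, so $x^{(i)}_j$ lives at `X[i - 1, j - 1]`.

```python
import numpy as np

# Hypothetical data set: N = 4 data pairs, each input is an n = 3 vector.
# Row i-1 of X holds the input vector x^(i); entry y[i-1] holds y^(i).
X = np.array([
    [3.0, 2.0, 15.0],
    [4.0, 3.0, 5.0],
    [2.0, 1.0, 40.0],
    [3.0, 2.0, 22.0],
])
y = np.array([250.0, 410.0, 180.0, 235.0])

N, n = X.shape      # N data points, each input of size n
x2_1 = X[1, 0]      # x^(2)_1: first element of the second input vector
print(N, n, x2_1)   # 4 3 4.0
```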
The typical example used in an introductory machine learning class is the house price data set. You have data pairs of the form $(x^{(i)}, y^{(i)})$. You feed your feature vector, $x^{(i)}$, to your function, and it produces some corresponding scalar value, $y^{(i)}$, in response. In this case, $x^{(i)}$ may be a set of measurements for the home: the number of bedrooms, the number of bathrooms, its age, and so on. The corresponding output is $y^{(i)}$, which denotes the price of the home (the real price, not a prediction).
Now, as I mentioned earlier, we rarely ever know what $f$ is. So what we'll do is model the relationship between each $x^{(i)}$ and $y^{(i)}$ as closely as we can. We approximate their relationship with a model function that we call $\hat{f}$:

$$y \approx \hat{f}(x)$$
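Note that $x$ and $y$ are just placeholders. Since we're really given $N$ data points, $x^{(i)}$ and $y^{(i)}$, we should write out the expanded form of the above by plugging in each data pair. That'll give us a clearer picture of what's going on:

$$
\begin{aligned}
y^{(1)} &\approx \hat{f}(x^{(1)}) \\
y^{(2)} &\approx \hat{f}(x^{(2)}) \\
&\;\;\vdots \\
y^{(N)} &\approx \hat{f}(x^{(N)})
\end{aligned}
$$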
This is starting to look more like a system of equations. But we're not quite there yet. How exactly do we pick $\hat{f}$?
Summary: Our goal in data fitting is to model the relationship between the inputs, $x^{(i)}$, and the outputs, $y^{(i)}$, as closely as possible using a model function, $\hat{f}$. We need a model because we don't know the true relationship, $f$.
Below is the general form of the model function used in least squares fitting:

$$\hat{f}(x) = \theta_1 f_1(x) + \theta_2 f_2(x) + \cdots + \theta_p f_p(x)$$
If this seems confusing, don't worry—it's easier than it looks.
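First, notice that the model $\hat{f}$ is a function like any other. In this case, though, it's composed of $p$ parameters, $\theta_1, \dots, \theta_p$, as well as $p$ functions, $f_1, \dots, f_p$. In data fitting, these functions are called basis functions. Importantly, $\hat{f}$ is linear in the parameters $\theta_i$, even if the basis functions themselves are nonlinear in $x$.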
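To make the general form concrete, here's a small sketch of a model evaluator. The helper name `model` and the particular basis functions are hypothetical choices for illustration; the point is that $\hat{f}(x)$ is just a weighted sum of whatever basis functions you pick, and the weights are the parameters.

```python
import math

def model(theta, basis, x):
    """Evaluate f_hat(x) = theta_1 f_1(x) + ... + theta_p f_p(x)."""
    if len(theta) != len(basis):
        raise ValueError("need exactly one parameter per basis function")
    return sum(t * f(x) for t, f in zip(theta, basis))

# Hypothetical example with p = 3 basis functions; note that sin(x) is
# nonlinear in x, yet the model is still linear in theta
basis = [lambda x: 1.0, lambda x: x, lambda x: math.sin(x)]
theta = [2.0, -1.0, 0.5]
print(model(theta, basis, 0.0))  # 2.0
```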
Okay, so how can we make this least squares model function more concrete?
Well, the good news is that we get to pick the basis functions based on how we think the real function, $f$, behaves. As we'll see shortly, if $f$ appears to be linear, then we may pick our basis functions such that $\hat{f}$ ends up resembling a straight line. On the other hand, if $f$ appears to be quadratic, then we may pick our basis functions such that $\hat{f}$ ends up being a quadratic polynomial.
Takeaway: We pick the basis functions based on how we think $f$ behaves. This is a key step in engineering a model function. In picking the basis functions, we also decide how many of them we'll need; that number is $p$.
We pick the basis functions. The $\theta_i$ values (the model parameters) are what we need to solve for.
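Let's plug this general form of $\hat{f}$ into the earlier set of equations that we saw:

$$
\begin{aligned}
y^{(1)} &\approx \theta_1 f_1(x^{(1)}) + \theta_2 f_2(x^{(1)}) + \cdots + \theta_p f_p(x^{(1)}) \\
y^{(2)} &\approx \theta_1 f_1(x^{(2)}) + \theta_2 f_2(x^{(2)}) + \cdots + \theta_p f_p(x^{(2)}) \\
&\;\;\vdots \\
y^{(N)} &\approx \theta_1 f_1(x^{(N)}) + \theta_2 f_2(x^{(N)}) + \cdots + \theta_p f_p(x^{(N)})
\end{aligned}
$$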
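Now that's more like it: this is a linear system of equations in the unknowns $\theta_1, \dots, \theta_p$! Let's represent it in matrix form:

$$
\underbrace{\begin{bmatrix}
f_1(x^{(1)}) & f_2(x^{(1)}) & \cdots & f_p(x^{(1)}) \\
f_1(x^{(2)}) & f_2(x^{(2)}) & \cdots & f_p(x^{(2)}) \\
\vdots & \vdots & \ddots & \vdots \\
f_1(x^{(N)}) & f_2(x^{(N)}) & \cdots & f_p(x^{(N)})
\end{bmatrix}}_{A}
\underbrace{\begin{bmatrix} \theta_1 \\ \theta_2 \\ \vdots \\ \theta_p \end{bmatrix}}_{\theta}
\approx
\underbrace{\begin{bmatrix} y^{(1)} \\ y^{(2)} \\ \vdots \\ y^{(N)} \end{bmatrix}}_{y}
$$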
Notice that our matrix $A$ has dimensions $N \times p$. In practice, $N$ is often much, much larger than $p$. Sound familiar? All this really means is that we have an overdetermined system, so there's typically no exact solution to $A\theta = y$. This means that many data fitting problems are actually least squares problems: we need to find the $\theta$ that brings $A\theta$ as close as possible to $y$, in the sense of minimizing $\lVert A\theta - y \rVert^2$.
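Here's a minimal sketch of how the matrix $A$ could be built in code, with entry $A_{ij} = f_j(x^{(i)})$. The helper name `design_matrix` and the sample inputs are hypothetical; the sketch uses scalar inputs ($n = 1$) and a straight-line basis ($p = 2$) to show the $N \times p$ shape.

```python
import numpy as np

def design_matrix(basis, xs):
    """Return the N x p matrix A with A[i, j] = f_j(x^(i)) (0-based indices)."""
    return np.array([[f(x) for f in basis] for x in xs], dtype=float)

# Hypothetical scalar inputs (n = 1) and a straight-line basis (p = 2)
xs = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
basis = [lambda x: 1.0, lambda x: x]
A = design_matrix(basis, xs)
print(A.shape)                    # (6, 2): N = 6 rows, p = 2 columns
print(np.linalg.matrix_rank(A))   # 2: the columns are linearly independent
```

With six equations and only two unknowns, this is exactly the tall, overdetermined shape described above.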
- There exists some unknown relationship, $f$, between $x$ and $y$, such that $y = f(x)$.
- We approximate $f$ using $\hat{f}$.
- We pick the basis functions based on how we think the real function behaves.
- We solve for the parameters of our model, $\theta$, using the least squares method.
Here's a five-step strategy you can use to solve least squares problems:
- Visualize the problem. For example, you may be given a set of data points that you can plot.
- Pick an appropriate model. Based on what we learned, this involves choosing the basis functions and $p$.
- Identify the equations involved. Write them out explicitly based on your input and output pairs.
- Solve the overdetermined system using the least squares method.
- (Optional) Visualize the solution. This is a useful way to sanity-check your answer, though it's not foolproof.
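Steps two through four can be sketched as one small function. The name `least_squares_fit` is hypothetical, and it delegates the solve to NumPy's `np.linalg.lstsq` rather than doing QR by hand; step five (plotting) is left out to keep it self-contained.

```python
import numpy as np

def least_squares_fit(basis, xs, ys):
    """Build A from the chosen basis functions, then solve the overdetermined
    system A theta ~ y in the least squares sense."""
    A = np.array([[f(x) for f in basis] for x in xs], dtype=float)
    y = np.asarray(ys, dtype=float)
    theta, _, rank, _ = np.linalg.lstsq(A, y, rcond=None)
    if rank < A.shape[1]:
        raise ValueError("basis functions are linearly dependent on this data")
    return theta

# Hypothetical points that lie exactly on y = 1 + 2x, so the fit should recover them
xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]
theta = least_squares_fit([lambda x: 1.0, lambda x: x], xs, ys)
print(theta)  # approximately [1. 2.]
```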
Note: For all the examples that follow, we'll let $n = 1$. That is, our inputs $x^{(i)}$ will just be scalar values. This simplification changes nothing about the least squares method itself.
Suppose we're given these data points for a least squares line fitting problem:
We're asked to model the relationship between $x$ and $y$. Let's take it step by step.
First, we'll plot the points:
We note that the points, while scattered, appear to follow a linear pattern. Clearly, no single straight line passes through all of the points, so we'll do our best to get as close as possible (using least squares, of course).
We know $f$ appears linear, like a $y = mx + b$ equation. We want our model function to look something like this:

$$\hat{f}(x) = \theta_1 + \theta_2 x$$
Note: Alternatively, you could just as well pick $\hat{f}(x) = \theta_1 x + \theta_2$. You'd get the same best-fit line; only the order of the parameters would change.
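So, we revisit our general model:

$$\hat{f}(x) = \theta_1 f_1(x) + \theta_2 f_2(x) + \cdots + \theta_p f_p(x)$$

Now we pick our basis functions, as promised, to give $\hat{f}$ a linear shape. First, we set $p = 2$, such that we have:

$$\hat{f}(x) = \theta_1 f_1(x) + \theta_2 f_2(x)$$

Next, we define our basis functions:

$$f_1(x) = 1, \qquad f_2(x) = x$$

What does this do for us? Let's plug them into the general formula:

$$\hat{f}(x) = \theta_1 (1) + \theta_2 (x) = \theta_1 + \theta_2 x$$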
That gives us precisely the function we wanted.
Note: You don't have to be this explicit about how you select your basis functions. However, I recommend doing so because it allows you to verify that your reasoning is sound.
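To check the earlier claim that writing the model as $\theta_1 x + \theta_2$ instead of $\theta_1 + \theta_2 x$ doesn't change the fit, here's a quick sketch on three made-up points (not the points from this example). Swapping the order of the basis functions swaps the columns of $A$, which simply swaps the entries of $\theta$.

```python
import numpy as np

# Hypothetical data (three points, not the ones from this example)
x = np.array([1.0, 2.0, 4.0])
y = np.array([2.0, 3.0, 7.0])

# Basis f_1(x) = 1, f_2(x) = x  ->  f_hat(x) = theta_1 + theta_2 x
A = np.column_stack([np.ones_like(x), x])
theta, *_ = np.linalg.lstsq(A, y, rcond=None)

# Reversed basis f_1(x) = x, f_2(x) = 1  ->  f_hat(x) = theta_1 x + theta_2
A_rev = np.column_stack([x, np.ones_like(x)])
theta_rev, *_ = np.linalg.lstsq(A_rev, y, rcond=None)

print(theta, theta_rev)  # the same two numbers, in swapped order
```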
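Here are all three equations for our problem:

$$
\begin{aligned}
y^{(1)} &\approx \theta_1 + \theta_2 x^{(1)} \\
y^{(2)} &\approx \theta_1 + \theta_2 x^{(2)} \\
y^{(3)} &\approx \theta_1 + \theta_2 x^{(3)}
\end{aligned}
$$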
Let's plug in our points:
And, in matrix form, this looks like the following:
Three equations and two unknowns: this is an overdetermined system. How do we solve it? Well, as we know, there's no exact solution. But we can get the least squares solution by solving for $\theta$ in the normal equations:

$$A^T A \theta = A^T y$$

Of course, we shouldn't solve this directly, since explicitly forming $A^T A$ can be numerically unstable. Instead, we'll first factor $A$ using QR decomposition. If you perform the necessary steps, you'll get $A = QR$ with the following factors:
You can verify this by multiplying $Q$ and $R$ to check that you do in fact get $A$ back. It looks pretty nasty with all those square root terms, but they actually cancel out quite nicely, as we'll see in a moment.
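Here's a sketch of that verification in NumPy, using a made-up three-point data set rather than the points from this example. It checks that $QR$ reproduces $A$, that $Q$ has orthonormal columns, and that the QR route agrees with the normal equations on this well-conditioned problem.

```python
import numpy as np

# Hypothetical 3 x 2 matrix for a straight-line fit (column of ones, then x)
A = np.array([[1.0, 1.0],
              [1.0, 2.0],
              [1.0, 4.0]])
y = np.array([2.0, 3.0, 7.0])

Q, R = np.linalg.qr(A)   # reduced QR: Q is 3 x 2, R is 2 x 2

# Sanity checks: QR reproduces A, Q^T Q = I, and R is upper-triangular
assert np.allclose(Q @ R, A)
assert np.allclose(Q.T @ Q, np.eye(2))
assert np.allclose(R, np.triu(R))

# The QR route and the normal equations A^T A theta = A^T y agree here
theta_qr = np.linalg.solve(R, Q.T @ y)
theta_normal = np.linalg.solve(A.T @ A, A.T @ y)
print(theta_qr, theta_normal)
```

One caveat: NumPy doesn't force the diagonal of $R$ to be positive, so its $Q$ and $R$ may differ from a hand-computed factorization by the signs of some columns of $Q$ and the matching rows of $R$. The solution $\theta$ is unaffected.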
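Let's plug $A = QR$ into the normal equations. Since $Q$ has orthonormal columns ($Q^T Q = I$) and $R$ is upper-triangular and invertible (assuming the columns of $A$ are linearly independent), this simplifies nicely:

$$
\begin{aligned}
A^T A \theta &= A^T y \\
(QR)^T (QR) \theta &= (QR)^T y \\
R^T Q^T Q R \theta &= R^T Q^T y \\
R^T R \theta &= R^T Q^T y \\
R \theta &= Q^T y
\end{aligned}
$$

The last step multiplies both sides on the left by $(R^T)^{-1}$.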
Let's plug in the actual matrices:
And let's simplify the right-hand side of the equation:
This is a square system! Even better, it's upper-triangular, which means we can solve for $\theta_2$ really easily and then plug it back into the first equation to solve for $\theta_1$ (recall that this strategy is known as back-substitution). First, let's explicitly write out the two equations:
Solving the second equation gives us $\theta_2$.
Plug that into the first equation:
Solving that equation yields $\theta_1$.
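Back-substitution works the same way for any invertible upper-triangular system, not just a $2 \times 2$ one: solve the last row first, then work upward. Here's a minimal sketch (the function name `back_substitute` and the sample system are hypothetical).

```python
import numpy as np

def back_substitute(R, b):
    """Solve R theta = b for an invertible upper-triangular R, starting from the last row."""
    R = np.asarray(R, dtype=float)
    b = np.asarray(b, dtype=float)
    p = len(b)
    theta = np.zeros(p)
    for i in range(p - 1, -1, -1):
        # Subtract the contributions of the unknowns already solved for, then divide
        theta[i] = (b[i] - R[i, i + 1:] @ theta[i + 1:]) / R[i, i]
    return theta

# Hypothetical 2 x 2 system:  2 theta_1 + 3 theta_2 = 8,  4 theta_2 = 8
print(back_substitute([[2.0, 3.0], [0.0, 4.0]], [8.0, 8.0]))  # [1. 2.]
```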
We have the following solution:
Remember our model?

$$\hat{f}(x) = \theta_1 + \theta_2 x$$
Plugging those in yields the following straight-line equation:
Let's plot the best-fit line along with the points:
Awesome! This is the best-fit line for the data points we were given.
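"Best fit" has a precise meaning here: no other choice of $\theta$ gives a smaller sum of squared residuals $\lVert A\theta - y \rVert^2$. Here's a small numerical sanity check of that property on made-up data (not this example's points); the helper name `sse` is hypothetical.

```python
import numpy as np

# Hypothetical straight-line data (not the points from this example)
x = np.array([1.0, 2.0, 4.0])
y = np.array([2.0, 3.0, 7.0])
A = np.column_stack([np.ones_like(x), x])

theta, *_ = np.linalg.lstsq(A, y, rcond=None)

def sse(t):
    """Sum of squared residuals ||A t - y||^2 for parameters t."""
    r = A @ t - y
    return float(r @ r)

# Nudging either parameter away from the least squares solution
# should never decrease the error
best = sse(theta)
for delta in ([0.1, 0.0], [-0.1, 0.0], [0.0, 0.1], [0.0, -0.1]):
    assert sse(theta + np.array(delta)) >= best
print(best)
```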
As you add more points, data fitting (particularly the QR factorization portion) becomes more difficult to do by hand. Fortunately, you can use languages like MATLAB or Python to solve these problems. But now, when you do rely on computers, you'll at least know what they're doing behind the scenes.
Let's not stop there! Suppose instead that we are given these five data points:
Let's repeat the process.
Here's a graph of our points:
To me, these points seem to take on the shape of a parabola. Based on that observation, I'm going to perform a least squares polynomial fit using a polynomial of degree two (a quadratic).
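Since we're modeling a quadratic (a degree-two polynomial), this is the general form of the model function we'll aim for:

$$\hat{f}(x) = \theta_1 + \theta_2 x + \theta_3 x^2$$

To get that, we'll start with the original form again:

$$\hat{f}(x) = \theta_1 f_1(x) + \theta_2 f_2(x) + \cdots + \theta_p f_p(x)$$

And we'll pick $p = 3$ with:

$$f_1(x) = 1, \qquad f_2(x) = x, \qquad f_3(x) = x^2$$

$$\hat{f}(x) = \theta_1 (1) + \theta_2 (x) + \theta_3 (x^2) = \theta_1 + \theta_2 x + \theta_3 x^2$$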
There we go!
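With these basis functions, the columns of $A$ are $1$, $x$, and $x^2$ evaluated at each input. A sketch of building that matrix with NumPy's `np.vander`, using the $x$-values that appear in the matrix `A` in the scripts later in this post:

```python
import numpy as np

# x-values from the second column of A in the scripts in this post
x = np.array([-4.0, 0.0, 1.0, 2.0, -6.0])

# increasing=True orders the columns as [1, x, x^2], matching
# f_1(x) = 1, f_2(x) = x, f_3(x) = x^2
A = np.vander(x, 3, increasing=True)
print(A)
```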
Note: Again, you could reverse the order of the polynomial to be $\theta_1 x^2 + \theta_2 x + \theta_3$. This would change nothing except the order of the elements in your resulting $\theta$ vector.
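This ordering detail matters in practice: NumPy's `np.polyfit` returns coefficients with the highest power first, which is the reversed order. Here's a quick sketch comparing it against the $[1, x, x^2]$ ordering; the $y$-values are hypothetical, chosen only to illustrate the ordering.

```python
import numpy as np

x = np.array([-4.0, 0.0, 1.0, 2.0, -6.0])
# Hypothetical y-values, used only to illustrate coefficient ordering
y = np.array([5.0, 1.0, 2.0, 6.0, 20.0])

A = np.vander(x, 3, increasing=True)          # columns [1, x, x^2]
theta, *_ = np.linalg.lstsq(A, y, rcond=None)  # [theta_1, theta_2, theta_3]

coeffs = np.polyfit(x, y, 2)                   # highest power first: [x^2, x, 1]
print(np.allclose(theta, coeffs[::-1]))        # True
```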
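Here are all five equations for our polynomial fitting problem:

$$y^{(i)} \approx \theta_1 + \theta_2 x^{(i)} + \theta_3 \left(x^{(i)}\right)^2, \qquad i = 1, \dots, 5$$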
Let's plug in the data we were given:
I'll simplify things a bit and represent this as a matrix equation:
Straight-line fitting is pretty simple by hand, but with polynomial least squares fitting, the hand calculations get tedious. So I'm going to "cheat" and use Python! You can use MATLAB instead if you'd prefer; the language doesn't really matter once you know the theory.
Here's a script that uses QR factorization explicitly:
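import numpy as np
from numpy import linalg as LA

# Our data: the columns of A are 1, x, and x^2 (one row per data point)
A = np.array([[1, -4, 16],
              [1, 0, 0],
              [1, 1, 1],
              [1, 2, 4],
              [1, -6, 36]])
y = np.array([y1, y2, y3, y4, y5])  # placeholders: substitute the five y-values from the data set above

# QR factorize A (reduced mode: Q is 5 x 3, R is 3 x 3)
Q, R = LA.qr(A)

# Solve R theta = Q^T y
theta = LA.solve(R, Q.T @ y)
print(theta)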
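However, you'll get the same result from the following code, which just uses NumPy's built-in `lstsq` function (internally, it relies on an SVD-based LAPACK routine rather than QR):
import numpy as np
from numpy import linalg as LA

# Our data: the columns of A are 1, x, and x^2 (one row per data point)
A = np.array([[1, -4, 16],
              [1, 0, 0],
              [1, 1, 1],
              [1, 2, 4],
              [1, -6, 36]])
y = np.array([y1, y2, y3, y4, y5])  # placeholders: substitute the five y-values from the data set above

# lstsq returns (solution, residuals, rank, singular values); keep only the solution
theta, residuals, rank, singular_values = LA.lstsq(A, y, rcond=None)
print(theta)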
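As a quick check that the two approaches agree, here's a sketch that runs both on the same $A$ with made-up $y$-values (an assumption purely for illustration, not the data from this example):

```python
import numpy as np
from numpy import linalg as LA

# Same design matrix as the scripts above: columns 1, x, x^2
A = np.array([[1, -4, 16],
              [1, 0, 0],
              [1, 1, 1],
              [1, 2, 4],
              [1, -6, 36]], dtype=float)

# Hypothetical y-values, used only to compare the two approaches
y = np.array([5.0, 1.0, 2.0, 6.0, 20.0])

# Approach 1: explicit QR, then solve R theta = Q^T y
Q, R = LA.qr(A)
theta_qr = LA.solve(R, Q.T @ y)

# Approach 2: NumPy's least squares solver
theta_lstsq, *_ = LA.lstsq(A, y, rcond=None)

print(np.allclose(theta_qr, theta_lstsq))  # True
```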
Regardless of which version we run, we'll get the same answer for the $\theta$ vector:
Plugging this into our model, we arrive at the following polynomial function:
And here's the resulting graph with our polynomial fit to the data:
Looks like a pretty good fit to me!
That does it for this series on the least squares method. I hope you found this tutorial helpful!