What is overfitting? Why is it bad? And how can I avoid it?

Two enter, one remains… (Source: Wikipedia)

Pop quiz. In the chart above of eleven historical datapoints, which of the two lines would be more accurate for predicting the location of a new, twelfth datapoint? The blue line seems great: it’s fit perfectly to your known historical datapoints! The diagonal black line seems “dumb” in comparison, as it only manages to intersect three of the eleven historical datapoints.

It may surprise some readers that the diagonal black line can predict the location of the next datapoint more accurately than the blue line.

Deciding how to choose between these models requires you to understand the concept of overfitting. Recognizing and combating overfitting is a core task in Machine Learning. There is no “magic bullet” approach to dealing with overfitting that will work in all use-cases, so it’s useful to gain some intuition.

How historical data creates predictive power

Every machine learning application has three components:

  1. historical data to train your algorithm,
  2. the algorithm, and
  3. the predicted values that the algorithm produces after training on the historical data.

A Machine Learning algorithm is a piece of code that handles a series of mathematical operations. These mathematical operations are governed by “parameters.” There are many kinds of parameters. Generally, you can think of parameters as instructions to process data in a certain way or to designate certain pieces of data as particularly important or unimportant.

Let’s look at a simple example. An algorithm that you might remember from high school math is the equation of a line in slope-intercept form: y = mx + b, where m is the slope and b is the y-intercept.

In the formula above, “y = mx + b” is an algorithm, and “m” and “b” are parameters. The algorithm defines how our estimate is produced. We can choose parameter values to approximate the historical data.
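To make “choosing parameter values” concrete, here is a minimal Python sketch that picks m and b by ordinary least squares. Least squares is one common way to fit a line and is an assumption here (the article doesn’t name a fitting method); the four historical datapoints and the helper names `predict` and `fit_line` are made up for illustration.

```python
def predict(x, m, b):
    """The 'algorithm': y = mx + b. m and b are its parameters."""
    return m * x + b


def fit_line(xs, ys):
    """Pick m and b that minimize squared error on the historical data
    (ordinary least squares, closed-form solution)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    m = sxy / sxx
    b = mean_y - m * mean_x
    return m, b


# Made-up historical data, for illustration only.
history_x = [1.0, 2.0, 3.0, 4.0]
history_y = [3.1, 4.9, 7.2, 8.8]

m, b = fit_line(history_x, history_y)
print(f"m = {m:.2f}, b = {b:.2f}")                      # m = 1.94, b = 1.15
print(f"estimate at x = 5: {predict(5.0, m, b):.2f}")   # 10.85
```

The algorithm (`predict`) never changes; training only changes the two numbers it is given.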

The idea is that if the parameterized algorithm can correctly approximate the past, it will do a good job at estimating the future.

Surprisingly, finding the parameterization for an algorithm that allows you to perfectly model the past is not actually what you want.

What is overfitting?

It may seem paradoxical, but getting an algorithm to perform perfectly on the historical data can make it useless for predicting the future. When a model fits the historical data much better than it predicts new data, that is called overfitting.

The University of Washington has a great video on YouTube that walks through overfitting. We’ll look at some of the slides they use in that talk.

Let’s take a look at some historical data plotted below. Our goal here is to guess where the data will end up on the far right side of the chart: does the trendline passing through the historical data put you at A, B, C, or D?

One line you could draw is a horizontal line that ends at B. This captures the average of the values, but it feels a bit off. This is because the last few data points are lower than the initial data points, so it seems that the data is trending downward.

Let’s adjust our algorithm so that its parameters allow it to account for the fact that the higher the salt concentration, the lower the yield. (For the purposes of this tutorial, it doesn’t matter what these measurements relate to. They could be anything.)

When we do this, we see we get a downward sloping trendline that terminates at C.

That’s good, but it still doesn’t quite account for the fact that the 4th, 5th, and 6th data points are quite high. Let’s draw a curve to account for this.

Why is overfitting bad?

Alternatively, we could tell the algorithm to give extreme weight to the change in slope between the data points. Because the data farthest to the right flattens in slope compared to the previous data, an algorithm trained to be extremely sensitive to changes at the higher end of salt concentration might forecast an upward slope.

Finally, we can draw a curve that intersects every point of historical data, which might look like the chart below. This feels obviously wrong. It is also an extreme example of overfitting.

As we can see in the charts above, relying too much on the available data will create a model that is not general enough to predict the future. When parameterizing algorithms, the goal is to find a fit to the historical data that generalizes.
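The “curve through every point” problem can be reproduced in a few lines of Python. This is a minimal sketch under stated assumptions: the eleven points are synthetic (a straight-line trend plus a small alternating wiggle, not the data in the charts), the flexible model is a Lagrange interpolating polynomial, and the simple model is a least-squares line. It compares each model’s error on the historical points with its error on a twelfth point.

```python
# Synthetic data (an assumption for illustration): 11 "historical" points on the
# trend y = 2x + 1, plus a deterministic +/-0.5 wiggle standing in for noise.
xs = [float(i) for i in range(11)]
ys = [2 * x + 1 + 0.5 * (-1) ** i for i, x in enumerate(xs)]


def fit_line(xs, ys):
    """Ordinary least-squares fit of y = m*x + b; returns (m, b)."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    m = sxy / sxx
    return m, mean_y - m * mean_x


def interpolate(xs, ys, x):
    """Evaluate the Lagrange polynomial (degree len(xs) - 1) that passes through every point."""
    total = 0.0
    for k, (xk, yk) in enumerate(zip(xs, ys)):
        basis = 1.0
        for j, xj in enumerate(xs):
            if j != k:
                basis *= (x - xj) / (xk - xj)
        total += yk * basis
    return total


def mse(predict, xs, ys):
    """Mean squared error of a prediction function on (xs, ys)."""
    return sum((predict(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


m, b = fit_line(xs, ys)


def line(x):
    """Simple model: two parameters."""
    return m * x + b


def curve(x):
    """Flexible model: hits all 11 historical points exactly."""
    return interpolate(xs, ys, x)


# A twelfth point from the same trend and wiggle.
x_new = 11.0
y_new = 2 * x_new + 1 + 0.5 * (-1) ** 11

print("training MSE     line:", mse(line, xs, ys), "curve:", mse(curve, xs, ys))
print("new-point error  line:", abs(line(x_new) - y_new), "curve:", abs(curve(x_new) - y_new))
```

Here the curve’s training error is exactly zero and the line’s is about 0.25, yet on the twelfth point the line is off by about 0.55 while the curve is off by about 1024: a perfect fit to the past, and a useless forecast.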

Let’s now add the missing data points to the chart to see how the different models would have performed on the real data.

The initial sloping line model actually performed pretty well for most of the data, but it fails to capture how the rate of change increases at higher salt concentrations.

The curved line almost perfectly captures the actual data. Recall that when we trained it, we taught the curve to recognize that the values are more or less flat on average at low salt concentrations, but that they start to slope downward as the salt concentration increases.

This is an example of a model that fits the training data in a way that is general enough to extend the underlying trend into the unknown values.

How can I avoid overfitting?

Generally speaking, there are three types of data used in Machine Learning:

  • Training data
  • Validation data
  • Test data

(Note: there is a lack of consistency in the Machine Learning community about terms, so what you read here may appear to conflict with resources you encounter elsewhere.)

Generally speaking, you use the training data and validation data to construct a model (or, put another way, to parameterize an algorithm), then you use a new set of data that the model has not seen before, the test data, to measure its efficacy.
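As a rough illustration, the three-way split can be sketched in Python as follows. The 60/20/20 proportions, the fixed seed, and the function name `split_data` are assumptions for illustration, and the random shuffle assumes the rows are independent of each other (true of panel data, but not of time-series data, where shuffling would leak the future into training).

```python
import random


def split_data(rows, train_frac=0.6, val_frac=0.2, seed=0):
    """Shuffle rows, then split them into training, validation, and test sets.

    The 60/20/20 default is an illustrative choice, not a rule."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)  # fixed seed so the split is reproducible
    n_train = round(len(rows) * train_frac)
    n_val = round(len(rows) * val_frac)
    train = rows[:n_train]
    val = rows[n_train:n_train + n_val]
    test = rows[n_train + n_val:]  # held back; never used to fit or tune the model
    return train, val, test


train, val, test = split_data(range(100))
print(len(train), len(val), len(test))  # 60 20 20
```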

Training and validation are handled somewhat differently for time-series and classification problems. We’ll walk through the time-series case in a later tutorial, as there is some additional complexity there. In the meantime, let’s use Monument (www.monument.ai) to take a closer look at panel data applications.

Training and validation can be hard to visualize for panel data (i.e. the type of data you have when you’re working on classification problems). This is because panel data is not “temporal” in the sense that there is no time dimension.

Time-series data is any kind of data where the same thing is measured repeatedly over regular time periods, like stock prices every minute or ocean temperature every month. Panel data is a general term for any kind of data that deals with attributes of entities, rather than measuring how something changes over time. Examples of panel data include demographic and payment-default information for credit card users, and historical student test scores across several subjects, which could be used to predict future test performance.

Classification and regression algorithms model data by its attributes rather than by when the events recorded in the data occurred, so these are the algorithms you would use on panel data.

For this type of data, the training and validation process is called “cross-validation.” Cross-validation can be hard to visualize because this work happens behind the scenes in most Machine Learning platforms, including Monument.

Usually, algorithms for panel data break the historical data you start with into equal parts, known as “folds.” Over several rounds of training, one fold is held out each round, and a model trained on the remaining folds attempts to predict the “missing” fold, as shown in the graphic below. No-code tools like Monument handle this automatically.

How 4-fold cross-validation works.
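For readers who want to see what happens behind the scenes, here is a minimal pure-Python sketch of 4-fold cross-validation. It is not how Monument implements it; the helper names (`k_fold_indices`, `cross_validate`, `fit_line`, `fit_mean`), the mean-squared-error metric, and the noise-free synthetic data are all assumptions for illustration. Each round holds out one fold, fits on the other three, and scores predictions on the held-out fold; the final score is the average over the four rounds.

```python
import random


def k_fold_indices(n_rows, k, seed=0):
    """Shuffle row indices and deal them into k folds of (nearly) equal size."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    return [indices[i::k] for i in range(k)]


def cross_validate(xs, ys, k, fit, predict, seed=0):
    """Average validation MSE over k rounds; each round holds out one fold,
    fits on the other k - 1 folds, and scores predictions on the held-out fold."""
    errors = []
    for held_out in k_fold_indices(len(xs), k, seed):
        held = set(held_out)
        train = [i for i in range(len(xs)) if i not in held]
        params = fit([xs[i] for i in train], [ys[i] for i in train])
        sq_errs = [(predict(params, xs[i]) - ys[i]) ** 2 for i in held_out]
        errors.append(sum(sq_errs) / len(sq_errs))
    return sum(errors) / len(errors)


def fit_line(xs, ys):
    """Ordinary least-squares fit of y = m*x + b; returns (m, b)."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    m = sxy / sxx
    return m, mean_y - m * mean_x


def fit_mean(xs, ys):
    """Baseline model: ignore x and always predict the training average."""
    return sum(ys) / len(ys)


# Synthetic data (an assumption): 20 rows that follow y = 3x - 2 exactly.
xs = [float(i) for i in range(20)]
ys = [3 * x - 2 for x in xs]

line_err = cross_validate(xs, ys, 4, fit_line, lambda p, x: p[0] * x + p[1])
mean_err = cross_validate(xs, ys, 4, fit_mean, lambda p, x: p)
print(f"4-fold CV error, line: {line_err:.3g}, mean: {mean_err:.3g}")
```

Because this toy data is exactly linear, the line’s cross-validated error is essentially zero while the mean baseline’s is large; with real, noisy data neither would be zero, but the comparison works the same way.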

The number of folds is set by a fold parameter. You can set it as high as the number of rows in your data: if your data has 100 rows, you could use 100 folds (a setup known as leave-one-out cross-validation). But, as we have been learning, it’s best to start with broad strokes. If increasing the number of folds improves the accuracy of the model (as measured by the validation error rate report, among other methods), then, great, keep adding folds. If not, keep it simple!

In Conclusion

Overfitting is a complex topic that is core to Machine Learning. There are many resources online about this important topic. As you navigate the real world, a general rule of thumb is that generalized models are more likely to be useful in predicting the future.

If you get similar validation error rates with a very simple model like a Linear Regression and a complex neural network like an LSTM, it’s probably better to use the Linear Regression. Advanced models should be used when the data is too noisy or the systems being analyzed are too complex to be modeled accurately with simple models.
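The rule of thumb above can be written down as a small selection function. This is a hypothetical sketch, not a feature of any platform: the 5% relative tolerance for “similar” error rates, the complexity ranking, and the error numbers in the example are all made-up assumptions for illustration.

```python
def pick_simplest(results, tolerance=0.05):
    """Return the name of the simplest model whose validation error is within
    `tolerance` (relative) of the best error.

    results: list of (name, complexity_rank, validation_error) tuples,
    where a lower complexity_rank means a simpler model."""
    best = min(err for _, _, err in results)
    close_enough = [r for r in results if r[2] <= best * (1 + tolerance)]
    return min(close_enough, key=lambda r: r[1])[0]


# Hypothetical validation errors, for illustration only.
results = [
    ("linear regression", 1, 0.210),
    ("LSTM", 2, 0.205),
]
print(pick_simplest(results))  # linear regression
```

If the LSTM’s error were dramatically lower (say 0.15 against 0.30), it would be the only model within tolerance and would be picked instead.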

Certain no-code Machine Learning platforms like Monument (www.monument.ai) handle a lot of this overfitting complexity for you. You don’t have to worry about handling the training and validation stages of your workflow as Monument does that automatically. And you can easily try a wide array of algorithms in a few seconds, to see which will provide the best results.
