Overfitting (or overtraining) is a common problem for supervised learning models in which behavior learned from a training dataset does not generalize well to an unseen test dataset. A frequent cause of overfitting is excessive model complexity (in random forests, for example, using trees that are too deep). The good news is that overfitting is easy to test for and to remedy. This paper will describe an approach to testing for overfitting using the probability distributions of binary classifier output and the Kolmogorov-Smirnov test.
Some Context for Overfitting Tests
To fully understand the overfitting tests in this paper, it helps to first revisit a few fundamental principles of statistical learning that are relevant to this discussion. When one trains a supervised learning model and provides its output (weights, insights, etc.) as a product, one is making two claims:
- The model has learned correlations/relationships among the input features themselves, and between the features and the target.
- These learned correlations will generalize to unseen data, such that predictions can be made and understood as arising from the learned feature correlations.
When a model overfits to the training data, neither claim can be made. We are left with a model that cannot be trusted to generalize to new data, and we cannot claim to understand why the model predicted what it did.
Kolmogorov-Smirnov Test Statistic
A straightforward way to test a binary classifier for overfitting is to plot the classifier output (a probability between zero and one) for both the train and test sets, split by true class (see Figure 1).
If the two claims from section one are to hold, then the train and test distributions in this plot must be consistent with one another. The Kolmogorov-Smirnov (KS) test checks exactly that. For a more rigorous mathematical description, one can read here. The two-sample KS test can be framed as a non-parametric hypothesis test of agreement between two samples (here, the classifier outputs on the train and test sets). The KS statistic is defined as

$$D_{n,m} = \sup_x \left| F_{1,n}(x) - F_{2,m}(x) \right|$$

where F_{1,n} and F_{2,m} are the empirical cumulative distribution functions of the first and second samples, of sizes n and m respectively, and sup is the supremum function.
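The definition is D = sup |F1(x) - F2(x)|. The sketch below makes it concrete by evaluating D directly from the two empirical CDFs and checking the result against `scipy.stats.ks_2samp`. The helper name `ks_statistic` is hypothetical. The Beta-distributed "train" and "test" scores are synthetic stand-ins, not data from this article.

```python
import numpy as np
from scipy import stats


def ks_statistic(sample_1, sample_2):
    """Hypothetical helper: D = sup_x |F_1(x) - F_2(x)| for two samples.

    Both empirical CDFs are right-continuous step functions that only change
    at observed values, so the supremum is attained at one of the pooled points.
    """
    s1 = np.sort(np.asarray(sample_1, dtype=float))
    s2 = np.sort(np.asarray(sample_2, dtype=float))
    grid = np.concatenate([s1, s2])
    # Empirical CDF: fraction of each sample that is <= x, at every pooled point
    f1 = np.searchsorted(s1, grid, side="right") / len(s1)
    f2 = np.searchsorted(s2, grid, side="right") / len(s2)
    return np.max(np.abs(f1 - f2))


rng = np.random.default_rng(0)
train_scores = rng.beta(5, 2, size=500)  # synthetic stand-in for train outputs
test_scores = rng.beta(4, 2, size=300)   # synthetic stand-in for test outputs

d_manual = ks_statistic(train_scores, test_scores)
d_scipy = stats.ks_2samp(train_scores, test_scores).statistic
print(d_manual, d_scipy)  # the two values should agree
```

Evaluating both CDFs only at the pooled sample points is enough. Between those points the difference F1 - F2 is constant.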
When the KS statistic takes low values (near zero), the p-value becomes large and we cannot reject the null hypothesis that the two distributions come from a common underlying distribution. This is exactly the outcome we need in order to support our two claims from section one. Strictly speaking, failing to reject the null is evidence of consistency, not proof of it.
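The sketch below shows how the statistic and the p-value move together. It uses synthetic normal samples, which are an illustrative assumption rather than classifier output. One pair of samples is drawn from the same distribution, and the other pair differs by a half-standard-deviation shift.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(42)

# Case 1: both samples drawn from the same distribution
same_a = rng.normal(0.0, 1.0, size=2000)
same_b = rng.normal(0.0, 1.0, size=2000)

# Case 2: second sample shifted by half a standard deviation
diff_a = rng.normal(0.0, 1.0, size=2000)
diff_b = rng.normal(0.5, 1.0, size=2000)

same = stats.ks_2samp(same_a, same_b)
diff = stats.ks_2samp(diff_a, diff_b)

print(f"same distribution: D={same.statistic:.3f}, p={same.pvalue:.3g}")
print(f"shifted:           D={diff.statistic:.3f}, p={diff.pvalue:.3g}")
```

In the first case D is small and the p-value is typically large, so the null hypothesis is not rejected. In the shifted case D is several times larger and the p-value is vanishingly small.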
To drive home the power of the KS test, I have included a test case: a random forest that is purposely overfitted by making it overly complex (tree depth = …).
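A minimal sketch of this kind of comparison follows. The synthetic dataset, the label noise (`flip_y`), and the two depth settings (`max_depth=3` versus unlimited depth) are all assumptions chosen for illustration. The helper `train_test_ks` is hypothetical. The sketch compares train and test positive-class scores for examples whose true label is 1.

```python
from scipy import stats
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Synthetic binary problem with some label noise (an assumption for illustration)
X, y = make_classification(n_samples=4000, n_features=20, n_informative=5,
                           flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5,
                                                    random_state=0)


def train_test_ks(max_depth):
    """Hypothetical helper: fit a forest, then KS-compare train vs test
    positive-class scores for the examples whose true label is 1."""
    clf = RandomForestClassifier(n_estimators=100, max_depth=max_depth,
                                 random_state=0)
    clf.fit(X_train, y_train)
    train_scores = clf.predict_proba(X_train)[:, 1][y_train == 1]
    test_scores = clf.predict_proba(X_test)[:, 1][y_test == 1]
    return stats.ks_2samp(train_scores, test_scores)


shallow = train_test_ks(max_depth=3)   # constrained trees
deep = train_test_ks(max_depth=None)   # trees grown until leaves are pure

print(f"max_depth=3:    D={shallow.statistic:.3f}, p={shallow.pvalue:.3g}")
print(f"max_depth=None: D={deep.statistic:.3f}, p={deep.pvalue:.3g}")
```

With fully grown trees, each training point sits in a pure leaf of every tree whose bootstrap sample contains it. That pushes its training score toward its own label and away from the score it receives on unseen data, so the deep forest is expected to show a much larger D.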
It is clear from visual inspection of these plots that one of these models produces output on the test set that does not reflect the correlations learned on the training set. The KS test statistic quantifies what we can verify visually. One very important side note: the accuracy, precision, and recall of the overfitted model can be HIGHER than those of the non-overfitted model. This is because all of those metrics depend on a single probability threshold and say nothing about whether the train and test distributions are consistent.
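A toy example with hypothetical, hand-picked scores shows why threshold-based metrics can miss this. The helper `accuracy` is illustrative. Every example lands on the correct side of a 0.5 threshold in both sets, so accuracy is identical. Yet the train and test score distributions for the positive class do not overlap at all.

```python
import numpy as np
from scipy import stats

threshold = 0.5

# Hypothetical scores: train outputs pushed to the extremes, test outputs
# hovering just either side of the threshold.
train_pos = np.array([0.95, 0.97, 0.98, 0.99])
train_neg = np.array([0.01, 0.02, 0.03, 0.05])
test_pos = np.array([0.55, 0.58, 0.60, 0.62])
test_neg = np.array([0.38, 0.40, 0.42, 0.45])


def accuracy(pos_scores, neg_scores, t):
    """Fraction of examples on the correct side of threshold t."""
    correct = np.sum(pos_scores >= t) + np.sum(neg_scores < t)
    return correct / (len(pos_scores) + len(neg_scores))


print("train accuracy:", accuracy(train_pos, train_neg, threshold))      # 1.0
print("test accuracy: ", accuracy(test_pos, test_neg, threshold))        # 1.0
print("KS (positives):", stats.ks_2samp(train_pos, test_pos).statistic)  # 1.0
```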
Conclusion & Code
The take-home: make these plots and check your binary classification model for overfitting or else you are at risk of shipping a model that will not generalize to new data and should not be put into a production pipeline.
Below is the function we use to create a KS plot, followed by an example of how to use it with example data provided by sklearn.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import stats


def make_ks_plot(y_train, train_proba, y_test, test_proba, bins=30, fig_sz=(10, 8)):
    '''
    OUTPUT: outputs KS test/train overtraining plots for classifier output
    INPUTS:
        y_train - array-like of true 0/1 labels for the train set
        train_proba - 1-D np.ndarray of positive-class probabilities, e.g.
                      sklearn predict_proba(X_train)[:, 1]. Same length as y_train.
        y_test - array-like of true 0/1 labels for the test set
        test_proba - 1-D np.ndarray of positive-class probabilities, e.g.
                     sklearn predict_proba(X_test)[:, 1]. Same length as y_test.
        bins - number of bins for viz. Default 30.
        fig_sz - (width, height) of the figure in inches. Default (10, 8).
    '''
    # Pair each true label with its positive-class score
    train = pd.DataFrame({"label": np.asarray(y_train),
                          "probability": np.asarray(train_proba)})
    test = pd.DataFrame({"label": np.asarray(y_test),
                         "probability": np.asarray(test_proba)})

    # decisions = [train +, train -, test +, test -]
    decisions = []
    for df in [train, test]:
        d1 = df["probability"][df["label"] == 1]
        d2 = df["probability"][df["label"] == 0]
        decisions += [d1, d2]

    # Shared binning range across all four distributions
    low = min(np.min(d) for d in decisions)
    high = max(np.max(d) for d in decisions)
    low_high = (low, high)

    plt.figure(figsize=fig_sz)

    # Train distributions as filled, normalized histograms
    plt.hist(decisions[0], color='r', alpha=0.5, range=low_high, bins=bins,
             histtype='stepfilled', density=True, label='+ (train)')
    plt.hist(decisions[1], color='b', alpha=0.5, range=low_high, bins=bins,
             histtype='stepfilled', density=True, label='- (train)')

    # Test distributions as points with approximate Poisson (sqrt N) error bars
    hist, edges = np.histogram(decisions[2], bins=bins, range=low_high, density=True)
    scale = len(decisions[2]) / sum(hist)
    err = np.sqrt(hist * scale) / scale
    center = (edges[:-1] + edges[1:]) / 2
    plt.errorbar(center, hist, yerr=err, fmt='o', c='r', label='+ (test)')

    hist, edges = np.histogram(decisions[3], bins=bins, range=low_high, density=True)
    scale = len(decisions[3]) / sum(hist)
    err = np.sqrt(hist * scale) / scale
    plt.errorbar(center, hist, yerr=err, fmt='o', c='b', label='- (test)')

    # get the KS score: train vs test outputs for the positive class
    ks = stats.ks_2samp(decisions[0], decisions[2])

    plt.xlabel("Classifier Output", fontsize=12)
    plt.ylabel("Arbitrary Normalized Units", fontsize=12)
    plt.xlim(0, 1)
    # Invisible artist so the KS result appears in the legend
    plt.plot([], [], ' ',
             label=f'KS Statistic (p-value) : {ks.statistic:.2f} ({ks.pvalue:.2f})')
    plt.legend(loc='best', fontsize=12)
    plt.show()
    plt.close()


# While not a beautiful example of the power of KS,
# here is an example of the KS viz in action
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

data = load_iris()
X = data.data
# Iris has three classes; binarize (versicolor vs. rest) because make_ks_plot expects 0/1 labels
y = (data.target == 1).astype(int)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

clf = RandomForestClassifier(n_estimators=100, max_depth=8, random_state=42)
clf.fit(X_train, y_train)

# Use positive-class probabilities, not the hard labels from predict()
train_proba = clf.predict_proba(X_train)[:, 1]
test_proba = clf.predict_proba(X_test)[:, 1]

make_ks_plot(y_train, train_proba, y_test, test_proba)
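To run the same check in an automated pipeline without a plot, a minimal sketch could look like the one below. It takes the same four inputs as `make_ks_plot`. The function name `ks_overfit_report` and the `alpha=0.05` default are hypothetical choices, not part of the original code. The sketch also tests the negative class, which the plot's legend does not report.

```python
import numpy as np
from scipy import stats


def ks_overfit_report(y_train, train_proba, y_test, test_proba, alpha=0.05):
    """Hypothetical helper: train-vs-test KS test per true class, no plotting.

    Returns {class_label: (statistic, p_value, consistent)}, where `consistent`
    is True when the null hypothesis (one common underlying distribution)
    cannot be rejected at significance level alpha.
    """
    y_train = np.asarray(y_train)
    y_test = np.asarray(y_test)
    train_proba = np.asarray(train_proba)
    test_proba = np.asarray(test_proba)
    report = {}
    for label in (1, 0):
        res = stats.ks_2samp(train_proba[y_train == label],
                             test_proba[y_test == label])
        report[label] = (res.statistic, res.pvalue, res.pvalue >= alpha)
    return report
```

The fixed `alpha` cutoff is a judgment call. With very large train and test sets, the KS test can reject even differences too small to matter, so it may be worth looking at the statistic itself as well as the p-value.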