You’ve built your classifier, run cross-validation and have a super high AUC. So you’re done, right? Maybe not.
Most classifiers output a score of how likely an observation is to be in the positive class. Usually these scores are between 0 and 1 and get called probabilities. However, these probabilities often do not reflect reality: a predicted probability of 20% may not mean the event has a 20% chance of happening; the true chance could be 10%, 50%, 70% or anything else. Our aim should be for the model outputs to accurately reflect the posterior probabilities \(P(Y=1|x)\).
In this post we will mainly focus on binary classifiers. Later in the post we will talk about how to extend these ideas to multiclass problems.
Why it happens
Our models can output inaccurate probabilities for a variety of reasons:
- Flawed model assumptions (e.g. independence in a Naive Bayes model)
- Hidden features not available at training time
- Deficiencies in the learning algorithm
In terms of learning algorithms, as noted by Niculescu-Mizil et al. and in my own research:
- Naive Bayes tends to push probabilities toward the extremes of 0 and 1.
- SVMs and boosted trees tend to push probabilities away from 0 and 1 (toward the centre).
- Neural networks and random forests tend to have well calibrated probabilities.
- Regularisation also tends to push probabilities toward the centre.
Do we care?
Whether or not we want well calibrated probabilities depends entirely on the problem we are trying to solve.
If we only need to rank observations from most likely to least likely, then calibration is unnecessary.
Examples of problems I have worked on where calibrated probabilities are extremely important:
- Loan default prediction – Banks generally set thresholds on the probabilities, e.g. automatically rejecting an application if the probability of default is above 30%.
- Ad Click Prediction – Deciding which ad to show and how much to bid. You might use a baseline click-through rate (CTR) and compare your prediction against it to decide how much more you are willing to pay for this ad impression. [1]
- Demographic Estimation for Websites – Imagine you have predictions of gender/age, as probabilities, for a number of web users. The gender distribution on a website can be estimated by simply averaging the probabilities of the users seen on the site. Any bias in the probabilities will therefore produce a bias in your estimate.
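To make the averaging concrete, here is a toy calculation with made-up probabilities for five users. The constant +0.1 bias is an assumption chosen purely to show how a miscalibration carries straight through to the site-level estimate.

```r
# Hypothetical per-user probabilities of being female, for users seen on one site
p.female <- c(0.9, 0.8, 0.3, 0.6, 0.4)

# Estimated share of female visitors: the mean of the per-user probabilities
est.share <- mean(p.female)                     # 0.6

# An assumed systematic over-prediction of 0.1 (capped at 1) shifts the
# site-level estimate by the same amount
biased.share <- mean(pmin(p.female + 0.1, 1))   # 0.7
```

Because the estimate is a plain average, no amount of extra users averages the bias away; only calibration removes it.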
Visualisation
A reliability diagram is a relatively simple technique for visualising how well calibrated our classifier is. As described in Niculescu-Mizil et al.:
On real problems where the true conditional probabilities are not known, model calibration can be visualized with reliability diagrams (DeGroot & Fienberg, 1982). First, the prediction space is discretized into ten bins. Cases with predicted value between 0 and 0.1 fall in the first bin, between 0.1 and 0.2 in the second bin, etc.
For each bin, the mean predicted value is plotted against the true fraction of positive cases. If the model is well calibrated the points will fall near the diagonal line.
Below we provide a piece of R code for producing reliability diagrams. Here we generalise the number of bins to a user-defined parameter.
```r
reliability.plot <- function(obs, pred, bins=10, scale=TRUE) {
  # Plots a reliability chart and histogram of a set of predictions from a classifier
  #
  # Args:
  #   obs: Vector of true labels. Should be binary (0 or 1)
  #   pred: Vector of predictions of each observation from the classifier. Should be
  #         real numbers
  #   bins: The number of bins to use in the reliability plot
  #   scale: Scale the pred to be between 0 and 1 before creating the reliability plot
  library(plyr)   # ldply()
  library(Hmisc)  # subplot()
  min.pred <- min(pred)
  max.pred <- max(pred)
  min.max.diff <- max.pred - min.pred
  if (scale) {
    pred <- (pred - min.pred) / min.max.diff
  }
  # Equal-width bins over the range of pred
  bin.pred <- cut(pred, bins)
  # For each bin: observed fraction of positives (V1) and mean prediction (V2)
  k <- ldply(levels(bin.pred), function(x) {
    idx <- x == bin.pred
    c(sum(obs[idx]) / length(obs[idx]), mean(pred[idx]))
  })
  # Empty bins give NaN, so drop them
  not.nan.idx <- !is.nan(k$V2)
  k <- k[not.nan.idx, ]
  plot(k$V2, k$V1, xlim=c(0,1), ylim=c(0,1), xlab="Mean Prediction",
       ylab="Observed Fraction", col="red", type="o", main="Reliability Plot")
  lines(c(0,1), c(0,1), col="grey")
  # Inset histogram of the predictions in the bottom-right corner
  subplot(hist(pred, xlab="", ylab="", main="", xlim=c(0,1), col="blue"),
          grconvertX(c(.8, 1), "npc"), grconvertY(c(0.08, .25), "npc"))
}
```
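If you only need the numbers behind the diagram, or want to avoid the plyr and Hmisc dependencies, the per-bin summary can be computed with base R. This is a sketch rather than a drop-in replacement: `reliability.table` is a hypothetical helper, it uses fixed-width bins over [0, 1] (matching the quoted description rather than the data range), and it does no rescaling, so `pred` is assumed to already be a probability.

```r
# obs: 0/1 labels; pred: predicted probabilities in [0, 1]; bins: number of equal-width bins
reliability.table <- function(obs, pred, bins = 10) {
  breaks <- seq(0, 1, length.out = bins + 1)
  bin <- cut(pred, breaks = breaks, include.lowest = TRUE)
  n <- as.vector(table(bin))                        # observations per bin
  out <- data.frame(
    mean.pred = as.vector(tapply(pred, bin, mean)), # x-axis of the diagram
    obs.frac  = as.vector(tapply(obs, bin, mean)),  # y-axis: fraction of positives
    n         = n
  )
  out[n > 0, ]                                      # drop empty bins
}

rt <- reliability.table(obs  = c(0, 0, 1, 0, 1, 1),
                        pred = c(0.05, 0.15, 0.12, 0.55, 0.58, 0.95))
rt
```

Points far from the line `obs.frac == mean.pred` are the ones that show up as deviations from the diagonal; bins with small `n` deserve less weight when reading the plot.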
In the figure below we show an example reliability plot. Ideally the reliability plot (red line) should be as close to the diagonal line as possible. As this example deviates significantly from the diagonal, calibrating the probabilities will likely help.
It is also worth mentioning that the mean of the predicted scores should ideally be close to the prior, i.e. the fraction of positive examples in the data.
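A numerical version of this check, as a sketch; `check.mean.vs.prior` is a hypothetical helper name and the example vectors are made up.

```r
# Compare the average predicted probability with the observed positive rate (the prior).
# A large gap suggests the scores are systematically biased up or down.
check.mean.vs.prior <- function(obs, pred) {
  c(mean.pred = mean(pred), prior = mean(obs), gap = mean(pred) - mean(obs))
}

res <- check.mean.vs.prior(obs = c(0, 0, 1, 0, 1), pred = c(0.2, 0.3, 0.7, 0.4, 0.9))
res   # mean.pred 0.5, prior 0.4, gap 0.1
```

A small gap is necessary but not sufficient: scores can match the prior on average while still being miscalibrated within individual bins.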
Techniques for calibration
Overfitting
The most important step is to hold out a separate dataset to perform calibration with; fitting the calibration model on the same data used to train the classifier risks overfitting. Our steps for calibration are:
- Split the dataset into train and test sets
- Split the train set into model training and calibration sets
- Train the model on the model training set
- Score the test and calibration sets with the model
- Train the calibration model on the calibration set scores
- Score the test set using the calibration model
How much data to use for calibration will depend on the amount of data you have available. The calibration model will generally only fit a small number of parameters, so you do not need a huge volume of data. I would aim for around 10% of your training data, with a minimum of 50 examples.
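The split described above can be sketched with base R index vectors. `calibration.split` is a hypothetical helper; the 20% test fraction is an assumption, while the 10% calibration fraction with a floor of 50 follows the rule of thumb in the text.

```r
# Returns disjoint row-index sets for model training, calibration and testing
calibration.split <- function(n, test.frac = 0.2, calib.frac = 0.1, min.calib = 50) {
  idx <- sample.int(n)                       # shuffle row indices 1..n
  n.test <- round(test.frac * n)
  test <- idx[seq_len(n.test)]
  rest <- setdiff(idx, test)                 # the "train" set of the text
  n.calib <- min(length(rest), max(min.calib, round(calib.frac * length(rest))))
  calib <- rest[seq_len(n.calib)]
  list(train = setdiff(rest, calib),         # fit the classifier here
       calib = calib,                        # fit the calibration model here
       test  = test)                         # final evaluation only
}

set.seed(1)
s <- calibration.split(1000)
sapply(s, length)   # train 720, calib 80, test 200
```

With a smaller dataset the floor kicks in, e.g. `calibration.split(300)` gives 50 calibration rows rather than 24.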
Platt Scaling
Platt Scaling essentially involves fitting a logistic regression on the classifier’s output. Originally developed to map the outputs of SVMs to probabilities [2], it is also well suited to the output of most other classifiers.
```r
# Fit a logistic regression of the true labels on the classifier's scores
calib.data.frame <- data.frame(y = Y.calib, x = Y.calib.pred)
calib.model <- glm(y ~ x, data = calib.data.frame, family = binomial)
# Apply the fitted sigmoid to the test-set scores
test.data.frame <- data.frame(x = Y.test.pred)
Y.test.pred.calibrated <- predict(calib.model, newdata = test.data.frame, type = "response")
```
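Since `Y.calib`, `Y.calib.pred` and `Y.test.pred` come from your own pipeline, here is a self-contained sketch on simulated data. The data-generating process is an assumption: the true probability is `plogis(2 * z)`, while the "classifier" score `plogis(0.5 * z)` is squashed toward the centre, mimicking the behaviour the text describes for SVMs and boosted trees.

```r
set.seed(42)

# Hypothetical data: labels drawn from the true probability, plus a squashed score
make.scores <- function(n) {
  z <- rnorm(n)
  data.frame(y = rbinom(n, 1, plogis(2 * z)),  # true 0/1 labels
             x = plogis(0.5 * z))              # uncalibrated classifier score
}
calib <- make.scores(2000)
test  <- make.scores(2000)

# Platt Scaling: logistic regression of the labels on the raw score
platt <- glm(y ~ x, data = calib, family = binomial)
test$calibrated <- predict(platt, newdata = test, type = "response")

log.loss <- function(y, p) -mean(y * log(p) + (1 - y) * log(1 - p))
ll.raw   <- log.loss(test$y, test$x)
ll.platt <- log.loss(test$y, test$calibrated)
c(raw = ll.raw, platt = ll.platt)
```

The fitted slope on `x` is positive, so the mapping is monotone and the ranking of the test observations is unchanged; only the probability values move.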
The reliability diagram below shows the original reliability plot (green) and after Platt Scaling (red).
Platt Scaling should not change the ranking of the observations, so measures such as AUC will be unaffected. However, measures like Log Loss [3] should improve. In this example, Log Loss was originally 0.422 and improved to 0.418.
Platt’s original paper suggests that, instead of using the original {0,1} targets in the calibration sample, we map them to:
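The contrast between a rank-based metric and Log Loss can be checked on a toy example. Both helpers are sketches: `auc` uses the rank-sum (Mann-Whitney) formula, and the transform `plogis(3 * qlogis(p))` is an arbitrary rank-preserving map standing in for any monotone calibration.

```r
# Log loss: average negative log-likelihood of the 0/1 labels under the predictions.
# eps clips predictions away from 0 and 1 to avoid log(0).
log.loss <- function(y, p, eps = 1e-15) {
  p <- pmin(pmax(p, eps), 1 - eps)
  -mean(y * log(p) + (1 - y) * log(1 - p))
}

# AUC via the rank-sum formula: only the ordering of p matters
auc <- function(y, p) {
  r <- rank(p)
  n1 <- sum(y == 1)
  n0 <- sum(y == 0)
  (sum(r[y == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

y <- c(0, 0, 1, 1)
p <- c(0.1, 0.4, 0.35, 0.8)
p.sharp <- plogis(3 * qlogis(p))   # monotone, rank-preserving transform

auc(y, p)              # 0.75
auc(y, p.sharp)        # 0.75, unchanged
log.loss(y, p)         # about 0.472
log.loss(y, p.sharp)   # different: the transform changes Log Loss but not AUC
```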
$$t_+=\frac{N_+ + 1}{N_+ + 2}$$
$$t_-=\frac{1}{N_-+2}$$
where \(N_+\) and \(N_-\) are the number of positive and negative examples in the calibration sample.
To some extent this introduces a level of regularisation. Imagine you only predicted probabilities of exactly 0 or 1 and classified every example correctly: your Log Loss would be zero. With Platt’s transformed targets, that same Log Loss would be non-zero, so the fit is no longer rewarded for pushing predictions all the way to 0 or 1. As Log Loss is what you are optimising when fitting the logistic regression, this acts as a form of regularisation.
In my experiments, this transformation had little to no effect on the reliability diagram and Log Loss, so it seems an unnecessary step. It may be useful if you have very few examples and overfitting is more of a concern (in which case regularisation would help). Alternatively, you could use a ridge or lasso regression.
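A sketch of the target mapping, using a hypothetical calibration sample in which the scores separate the classes perfectly. With raw 0/1 labels, `glm` coefficients would diverge on such data; with Platt’s targets the optimum is finite, which illustrates the regularising effect. `platt.targets` is a hypothetical helper, and using `quasibinomial` is my choice to accept fractional responses: it gives the same coefficient estimates as `binomial`, without the non-integer response warning.

```r
# Platt's target mapping: shrink hard 0/1 labels slightly toward the middle
platt.targets <- function(y) {
  n.pos <- sum(y == 1)
  n.neg <- sum(y == 0)
  ifelse(y == 1, (n.pos + 1) / (n.pos + 2), 1 / (n.neg + 2))
}

# Hypothetical calibration sample: 3 positives and 7 negatives, perfectly separated by x
y <- c(1, 0, 0, 1, 0, 0, 0, 1, 0, 0)
x <- c(0.9, 0.2, 0.4, 0.7, 0.1, 0.3, 0.5, 0.8, 0.35, 0.6)
targets <- platt.targets(y)   # positives -> 4/5 = 0.8, negatives -> 1/9 (about 0.111)

fit <- glm(targets ~ x, family = quasibinomial)
coef(fit)
```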
Isotonic Regression
With Isotonic Regression you make the assumption:
$$y_i = m(f_i) + \epsilon_{i}$$
where \(m\) is an isotonic (monotonically non-decreasing) function. These are exactly the same assumptions we would use for least squares regression, except \(m\) is now an isotonic function instead of a linear one.
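The least-squares isotonic fit is usually computed with the pool-adjacent-violators algorithm (PAVA). Below is a minimal sketch, assuming equal weights and that `y` is already ordered by classifier score; base R’s `isoreg()` computes the same fit.

```r
# Minimal PAVA for a non-decreasing least-squares fit
pava <- function(y) {
  vals <- numeric(0)   # fitted value (mean) of each block
  sizes <- integer(0)  # number of observations in each block
  for (yi in y) {
    vals <- c(vals, yi)
    sizes <- c(sizes, 1L)
    # Merge the last two blocks while they violate monotonicity
    while (length(vals) > 1 && vals[length(vals) - 1] > vals[length(vals)]) {
      k <- length(vals)
      merged.size <- sizes[k - 1] + sizes[k]
      merged.val <- (vals[k - 1] * sizes[k - 1] + vals[k] * sizes[k]) / merged.size
      vals <- c(vals[seq_len(k - 2)], merged.val)
      sizes <- c(sizes[seq_len(k - 2)], merged.size)
    }
  }
  rep(vals, sizes)     # expand each block back to one value per observation
}

y <- c(1, 3, 2, 4, 3, 5)
pava(y)        # 1.0 2.5 2.5 3.5 3.5 5.0
isoreg(y)$yf   # same fit from base R
```

Each violating pair is replaced by its mean, which is why the fitted function is a staircase of flat steps.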
Below is an R example of how to perform isotonic regression using the isoreg function.
```r
fit.isoreg <- function(iso, x0) {
  # Predicts calibrated probabilities for new scores x0 from an isoreg() fit,
  # interpolating linearly across each step of the fitted step function
  o <- iso$o
  # isoreg() only stores the ordering o when the scores were not already sorted
  if (is.null(o)) o <- seq_along(iso$x)
  x <- iso$x[o]
  y <- iso$yf
  ind <- cut(x0, breaks = x, labels = FALSE, include.lowest = TRUE)
  min.x <- min(x)
  max.x <- max(x)
  # Step boundaries: the first knot plus every knot with a fitted value above 0
  adjusted.knots <- unique(iso$iKnots[c(1, which(iso$yf[iso$iKnots] > 0))])
  fits <- sapply(seq_along(x0), function(i) {
    j <- ind[i]
    # Handles the case where unseen data is outside the range of the calibration data
    if (is.na(j)) {
      if (x0[i] > max.x) {
        j <- length(x)
      } else if (x0[i] < min.x) {
        j <- 1
      }
    }
    # Find the upper and lower ends of the step (Inf, and so NA below, past the last knot)
    upper.step.n <- suppressWarnings(min(which(adjusted.knots > j)))
    upper.step <- adjusted.knots[upper.step.n]
    lower.step <- ifelse(upper.step.n == 1, 1, adjusted.knots[upper.step.n - 1])
    # Perform a linear interpolation between the start and end of the step
    denom <- x[upper.step] - x[lower.step]
    denom <- ifelse(denom == 0, 1, denom)
    val <- y[lower.step] + (y[upper.step] - y[lower.step]) * (x0[i] - x[lower.step]) / denom
    # Ensure we bound the probabilities to [0, 1]
    val <- ifelse(val > 1, 1, val)
    val <- ifelse(val < 0, 0, val)
    # NA at the right extreme of the distribution: use the largest fitted value
    val <- ifelse(is.na(val), y[length(y)], val)
    val
  })
  fits
}

# Remove any duplicate scores, as cut() needs unique breaks
idx <- duplicated(Y.calib.pred)
Y.calib.pred.unique <- Y.calib.pred[!idx]
Y.calib.unique <- Y.calib[!idx]
iso.model <- isoreg(Y.calib.pred.unique, Y.calib.unique)
Y.test.pred.calibrated <- fit.isoreg(iso.model, Y.test.pred)
```
In the figure below we show an example of the sort of function fitted by the isotonic regression model: notice how it goes up in steps instead of following a smooth curve. To smooth the fit, we perform a linear interpolation between the steps.
In the reliability plot above, the original uncalibrated scores are shown in green and the isotonic regression scores in red. In this example isotonic regression actually made things worse: the Log Loss went from 0.422 to 0.426, and the AUC was also reduced.
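A more compact alternative, sketched under my own assumptions rather than taken from the `fit.isoreg` code above: interpolate linearly between the fitted values at the `isoreg()` knots (plus both end points) with base R’s `approx()`. `isotonic.calibrator` is a hypothetical name, and `rule = 2` holds the end values constant outside the calibration range, so the output stays within the range of the fitted values.

```r
isotonic.calibrator <- function(calib.pred, calib.y) {
  fit <- isoreg(calib.pred, calib.y)
  # isoreg() only stores the ordering `o` when the scores were not already sorted
  xs <- if (is.null(fit$o)) fit$x else fit$x[fit$o]
  # Knot indices (in sorted order) where the step function jumps, plus both end points
  knots <- sort(unique(c(1, fit$iKnots, length(xs))))
  kx <- xs[knots]
  ky <- fit$yf[knots]
  # Linear interpolation between knots; rule = 2 extrapolates with the end values
  function(new.pred) approx(kx, ky, xout = new.pred, rule = 2)$y
}

calib.pred <- c(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7)
calib.y    <- c(0,   0,   1,   0,   1,   1,   1)
calibrate <- isotonic.calibrator(calib.pred, calib.y)
calibrate(c(0, 0.35, 0.45, 1))
```

Because the knot values are non-decreasing, the interpolated curve is still monotone, so it keeps the isotonic property while removing the flat steps.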
Multiclass Classification
What happens if you have more than two classes? Firstly, I would recommend visualising the problem as a series of reliability diagrams: for k classes, you can create k one-vs-rest reliability diagrams.
Secondly, you can take the scores for each of your k classes and plug them into a multinomial logistic regression. The superb glmnet package implements multinomial logistic regression. You can set the regularisation parameter to something quite small. One word of caution: if you have many classes, overfitting can become an issue, so at that point it is worth optimising the regularisation parameter.
If your favourite machine learning model (e.g. SVM) doesn’t directly support multiclass classification, you can fit a one-vs-all set of classifiers and then plug each of those scores into the multinomial logistic regression.
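The one-vs-rest reliability data can be computed per class with base R. In this sketch, `multiclass.reliability` is a hypothetical helper; it assumes `probs` is an n x k matrix of predicted probabilities with columns named by class, and `labels` holds the true class names.

```r
multiclass.reliability <- function(probs, labels, bins = 10) {
  breaks <- seq(0, 1, length.out = bins + 1)
  lapply(setNames(colnames(probs), colnames(probs)), function(cls) {
    pred <- probs[, cls]
    obs <- as.numeric(labels == cls)          # 1 if this class, 0 otherwise
    bin <- cut(pred, breaks = breaks, include.lowest = TRUE)
    n <- as.vector(table(bin))
    out <- data.frame(mean.pred = as.vector(tapply(pred, bin, mean)),
                      obs.frac  = as.vector(tapply(obs, bin, mean)),
                      n = n)
    out[n > 0, ]                              # drop empty bins
  })
}

probs <- matrix(c(0.71, 0.17, 0.12,
                  0.15, 0.77, 0.08,
                  0.55, 0.33, 0.12), ncol = 3, byrow = TRUE,
                dimnames = list(NULL, c("a", "b", "c")))
labels <- c("a", "b", "b")
rel <- multiclass.reliability(probs, labels)
rel$b
```

Each element of `rel` can then be plotted as its own reliability diagram.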
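The text recommends glmnet; as a lighter-weight sketch, the snippet below swaps in `multinom()` from the nnet package (a recommended package shipped with standard R installations). `decay` is passed through to nnet’s weight-decay (L2) penalty, playing the role of the small regularisation parameter. The scores and labels are simulated stand-ins for your one-vs-all outputs.

```r
library(nnet)  # multinom()

set.seed(7)
# Hypothetical 1-vs-all scores for k = 3 classes; random numbers to show the mechanics
n <- 300
scores <- data.frame(s.a = runif(n), s.b = runif(n), s.c = runif(n))
# Simulated labels: each row's class is drawn with probability proportional to its scores
y <- factor(apply(scores, 1, function(s) sample(c("a", "b", "c"), 1, prob = s)))

calib <- cbind(scores[1:200, ], y = y[1:200])   # calibration set
test  <- scores[201:300, ]                       # held-out scores to calibrate

# Multinomial logistic regression on the k scores, with a small decay penalty
fit <- multinom(y ~ s.a + s.b + s.c, data = calib, decay = 1e-3, trace = FALSE)
calibrated <- predict(fit, newdata = test, type = "probs")  # one column per class
head(calibrated)
```

Unlike the raw one-vs-all scores, each row of `calibrated` sums to 1.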
Summary
Classifier probability calibration can be an important step in your machine learning pipeline. The first step is always to visualise the calibration and see how much of an issue you have. In general I have found Platt Scaling to be the simplest and most effective approach for calibrating most classification problems.
[1] See Google’s paper on issues with ad click prediction: http://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf
[2] See Platt’s original paper: http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.41.1639
[3] Log loss definition: https://www.kaggle.com/wiki/LogarithmicLoss