caretEnsemble is a package for making ensembles of caret
models. You should already be somewhat familiar with the caret package
before trying out caretEnsemble.
caretEnsemble has 3 primary functions:
caretList,
caretEnsemble and
caretStack. caretList is used
to build lists of caret models on the same training data, with the same
re-sampling parameters. caretEnsemble and
caretStack are used to create ensemble models from such
lists of caret models. caretEnsemble uses a glm to create a
simple linear blend of models and caretStack uses a caret
model to combine the outputs from several component caret models.
caretList is a flexible function for fitting many
different caret models, with the same resampling parameters, to the same
dataset. It returns a convenient list of caret objects
which can later be passed to caretEnsemble and
caretStack. caretList has almost exactly the
same arguments as train (from the caret package), with the
exception that the trControl argument comes last. It can
handle both the formula interface and the explicit x,
y interface to train. As in caret, the formula interface
introduces some overhead and the x, y
interface is preferred.
caretList has 2 arguments that can be used to
specify which models to fit: methodList and
tuneList. methodList is a simple character
vector of methods that will be fit with the default train
parameters, while tuneList can be used to customize the
call to each component model and will be discussed in more detail later.
First, let's build an example dataset (adapted from the caret
vignette):
#Adapted from the caret vignette
library("caret")
library("mlbench")
library("pROC")
data(Sonar)
set.seed(107)
inTrain <- createDataPartition(y = Sonar$Class, p = .75, list = FALSE)
training <- Sonar[inTrain, ]
testing <- Sonar[-inTrain, ]
my_control <- trainControl(
method="boot",
number=25,
savePredictions="final",
classProbs=TRUE,
index=createResample(training$Class, 25),
summaryFunction=twoClassSummary
)
Notice that we are explicitly setting the resampling
index to be used in trainControl. If you do
not set this index manually, caretList will attempt to set
it automatically, but it's generally a good idea to set it
yourself.
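To see why a fixed index matters, here is a minimal sketch on a toy outcome vector (the vector y and the names idx, ctrl_a and ctrl_b are illustrative, not part of the original example). It assumes only that trainControl stores the index it is given, so two control objects built from the same index describe identical resamples:

```r
library("caret")

# Toy two-class outcome; labels and class sizes are illustrative only
set.seed(107)
y <- factor(rep(c("M", "R"), times = c(60, 40)))

# One shared set of bootstrap indexes: a list of 25 integer vectors,
# each the same length as y (sampled with replacement)
idx <- createResample(y, times = 25)
length(idx)
unique(lengths(idx))

# Control objects built from the same index resample identically
ctrl_a <- trainControl(method = "boot", number = 25, index = idx)
ctrl_b <- trainControl(method = "boot", number = 25, index = idx)
identical(ctrl_a$index, ctrl_b$index)
```

Calling createResample a second time would draw a different set of rows, which is why the index is created once and reused.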
Now we can use caretList to fit a series of models (each
with the same trControl):
library("rpart")
library("caretEnsemble")
##
## Attaching package: 'caretEnsemble'
## The following object is masked from 'package:ggplot2':
##
## autoplot
model_list <- caretList(
Class~., data=training,
trControl=my_control,
methodList=c("glm", "rpart")
)
(As with train, the formula interface is convenient but
introduces more overhead. For large datasets, explicitly passing
x and y is preferred.) We can use the
predict function to extract predictions from this object
for new data:
p <- as.data.frame(predict(model_list, newdata=head(testing)))
print(p)

| glm | rpart |
|---|---|
| 0.00e+00 | 0.8750000 |
| 1.00e+00 | 0.8750000 |
| 1.00e+00 | 0.1818182 |
| 2.35e-05 | 0.1818182 |
| 0.00e+00 | 0.1818182 |
| 1.00e+00 | 0.1818182 |
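The x, y interface mentioned above can be sketched as follows. This is a minimal, self-contained example rather than the vignette's own code: the names ctrl, x_train, y_train and models_xy are illustrative, the number of resamples is reduced to 5 to keep it quick, and it assumes caretList forwards x and y to train unchanged.

```r
library("caret")
library("caretEnsemble")
library("mlbench")
library("rpart")

data(Sonar)
set.seed(107)

# Predictors and outcome passed separately instead of through a formula
x_train <- Sonar[, names(Sonar) != "Class"]
y_train <- Sonar$Class

# Shared resampling index, as recommended in the text
ctrl <- trainControl(
  method = "boot",
  number = 5,
  savePredictions = "final",
  classProbs = TRUE,
  index = createResample(y_train, 5),
  summaryFunction = twoClassSummary
)

models_xy <- caretList(
  x = x_train,
  y = y_train,
  trControl = ctrl,
  metric = "ROC",
  methodList = c("glm", "rpart")
)
names(models_xy)
```

The glm fit on this dataset may emit convergence warnings; they do not stop the list from being built.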
If you desire more control over the model fit, use the
caretModelSpec function to construct a list of model specifications
for the tuneList argument. This argument can be used to fit
several different variants of the same model, and can also be used to
pass arguments through train down to the component
functions (e.g. trace=FALSE for nnet):
library("mlbench")
library("randomForest")
library("nnet")
model_list_big <- caretList(
Class~., data=training,
trControl=my_control,
metric="ROC",
methodList=c("glm", "rpart"),
tuneList=list(
rf1=caretModelSpec(method="rf", tuneGrid=data.frame(.mtry=2)),
rf2=caretModelSpec(method="rf", tuneGrid=data.frame(.mtry=10), preProcess="pca"),
nn=caretModelSpec(method="nnet", tuneLength=2, trace=FALSE)
)
)
Finally, you should note that caretList does not support
custom caret models. Fitting those models is beyond the scope of this
vignette, but if you do so, you can manually add them to the model list
(e.g. model_list_big[["my_custom_model"]] <- my_custom_model).
Just be sure to use the same re-sampling indexes in
trControl as you use in the caretList
models!
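Returning to the tuneList specifications used above: as far as we can tell, caretModelSpec simply bundles the method name with any extra arguments into a plain list, which caretList later hands to train. A small sketch for inspecting one specification (the name spec is illustrative):

```r
library("caretEnsemble")

# Build one specification and look at what it contains
spec <- caretModelSpec(method = "nnet", tuneLength = 2, trace = FALSE)
str(spec)

# The method and the pass-through arguments are ordinary list elements
spec$method
spec$trace
```

Because each specification is just a list, several variants of the same method can sit side by side in tuneList under different names, as rf1 and rf2 do above.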
caretList is the preferred way to construct a list of
caret models in this package, as it will ensure the resampling indexes
are identical across all models. Let's take a closer look at our list of
models:
xyplot(resamples(model_list))
As you can see from this plot, these 2 models are uncorrelated, and the rpart model is occasionally anti-predictive, with one re-sample showing an AUC of 0.46.
We can confirm the 2 models' correlation with the
modelCor function from caret (caret has a lot of convenient
functions for analyzing lists of models):
modelCor(resamples(model_list))
##               glm       rpart
## glm    1.00000000 -0.01658742
## rpart -0.01658742  1.00000000
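Conceptually, the matrix above is a correlation of per-resample metric values, one column per model. The sketch below reproduces that idea in base R on made-up numbers; the data frame resample_auc and its values are hypothetical and are not the results of the models fitted here.

```r
# Hypothetical AUC per bootstrap resample for two models (made-up values)
resample_auc <- data.frame(
  glm   = c(0.70, 0.65, 0.72, 0.68, 0.74),
  rpart = c(0.71, 0.73, 0.66, 0.75, 0.69)
)

# Pairwise correlation across resamples: low values suggest the models
# make different mistakes and may combine well in an ensemble
model_cor <- cor(resample_auc)
model_cor
```

This only works because every model was scored on the same resamples, which is what the shared index guarantees.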
These 2 models make a good candidate for an ensemble: their predictions are fairly uncorrelated, but their overall accuracy is similar. We do a simple, linear greedy optimization on AUC using caretEnsemble:
greedy_ensemble <- caretEnsemble(
model_list,
metric="ROC",
trControl=trainControl(
number=2,
summaryFunction=twoClassSummary,
classProbs=TRUE
))
summary(greedy_ensemble)
## The following models were ensembled: glm, rpart
## They were weighted:
## 1.3764 -1.0734 -1.8211
## The resulting ROC is: 0.7379
## The fit for each individual model on the ROC is:
## method ROC ROCSD
## glm 0.6884429 0.08365065
## rpart 0.7013261 0.05304238
The ensemble's AUC on the training set resamples is 0.74, which is about 5% better than the best individual model. We can confirm this finding on the test set:
library("caTools")
model_preds <- lapply(model_list, predict, newdata=testing, type="prob")
model_preds <- lapply(model_preds, function(x) x[, "M"])
model_preds <- data.frame(model_preds)
ens_preds <- predict(greedy_ensemble, newdata=testing, type="prob")
model_preds$ensemble <- ens_preds
caTools::colAUC(model_preds, testing$Class)
##               glm     rpart  ensemble
## M vs. R 0.7137346 0.7746914 0.8140432
The ensemble's AUC on the test set is about 5% higher than the best individual model.
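For readers who want to see what an AUC measures, here is a minimal base R sketch of the rank-based (Mann-Whitney) formulation. The helper auc_rank and the four toy scores are illustrative and are not part of caTools; colAUC may differ in details, for example in how it treats AUCs below 0.5, so check its documentation before comparing numbers.

```r
# Rank-based AUC: the probability that a randomly chosen positive case
# scores higher than a randomly chosen negative case (ties count half)
auc_rank <- function(scores, labels, positive) {
  is_pos <- labels == positive
  n_pos <- sum(is_pos)
  n_neg <- sum(!is_pos)
  r <- rank(scores)  # average ranks handle tied scores
  (sum(r[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Toy example: 3 of the 4 positive/negative pairs are ordered correctly
scores <- c(0.10, 0.40, 0.35, 0.80)
labels <- factor(c("R", "R", "M", "M"))
auc_rank(scores, labels, positive = "M")  # 0.75
```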
We can also use varImp to extract the variable importances from each member of the ensemble, as well as the final ensemble model:
varImp(greedy_ensemble)

| | overall | glm | rpart |
|---|---|---|---|
| V4 | 0.0000000 | 0.0000000 | 0.000000 |
| V59 | 0.0289819 | 0.0781531 | 0.000000 |
| V41 | 0.0770670 | 0.2078205 | 0.000000 |
| V53 | 0.0913891 | 0.2464417 | 0.000000 |
| V56 | 0.1204130 | 0.3247080 | 0.000000 |
| V38 | 0.1247390 | 0.3363738 | 0.000000 |
| V8 | 0.1311981 | 0.3537914 | 0.000000 |
| V19 | 0.1328616 | 0.3582773 | 0.000000 |
| V47 | 0.1394219 | 0.3759677 | 0.000000 |
| V55 | 0.1589879 | 0.4287299 | 0.000000 |
| V26 | 0.1765625 | 0.4761218 | 0.000000 |
| V48 | 0.1779364 | 0.4798269 | 0.000000 |
| V29 | 0.2160511 | 0.5826078 | 0.000000 |
| V42 | 0.2333095 | 0.6291469 | 0.000000 |
| V6 | 0.2343866 | 0.6320516 | 0.000000 |
| V20 | 0.2397789 | 0.6465924 | 0.000000 |
| V1 | 0.2593621 | 0.6994010 | 0.000000 |
| V2 | 0.2652978 | 0.7154073 | 0.000000 |
| V54 | 0.3097143 | 0.8351818 | 0.000000 |
| V40 | 0.3216242 | 0.8672981 | 0.000000 |
| V60 | 0.3563664 | 0.9609847 | 0.000000 |
| V14 | 0.3572180 | 0.9632810 | 0.000000 |
| V3 | 0.3595286 | 0.9695119 | 0.000000 |
| V52 | 0.3948310 | 1.0647091 | 0.000000 |
| V58 | 0.4186773 | 1.1290135 | 0.000000 |
| V35 | 0.4268529 | 1.1510598 | 0.000000 |
| V15 | 0.4676095 | 1.2609650 | 0.000000 |
| V51 | 0.5138987 | 1.3857892 | 0.000000 |
| V21 | 0.5196423 | 1.4012777 | 0.000000 |
| V57 | 0.5348672 | 1.4423335 | 0.000000 |
| V18 | 0.5821234 | 1.5697653 | 0.000000 |
| V34 | 0.5993462 | 1.6162087 | 0.000000 |
| V49 | 0.7437472 | 2.0056034 | 0.000000 |
| V5 | 0.8256171 | 2.2263753 | 0.000000 |
| V33 | 0.9981225 | 2.6915566 | 0.000000 |
| V50 | 1.0826702 | 2.9195497 | 0.000000 |
| V43 | 1.2909897 | 3.4813082 | 0.000000 |
| V30 | 1.3041581 | 3.5168184 | 0.000000 |
| V23 | 1.3214433 | 3.5634300 | 0.000000 |
| V22 | 1.3849784 | 3.7347601 | 0.000000 |
| V28 | 1.6013217 | 4.3181557 | 0.000000 |
| V24 | 1.8528812 | 4.9965160 | 0.000000 |
| V25 | 2.0614003 | 5.5588126 | 0.000000 |
| V32 | 2.8817970 | 7.7711105 | 0.000000 |
| V44 | 3.0154479 | 0.8022216 | 4.319940 |
| V39 | 3.0934591 | 0.7034846 | 4.502128 |
| V37 | 3.3454581 | 0.8436842 | 4.820022 |
| V17 | 3.3606302 | 0.7580712 | 4.894597 |
| V46 | 3.3678171 | 0.9310958 | 4.804038 |
| V36 | 3.4872152 | 1.2259491 | 4.820022 |
| V7 | 3.5999626 | 1.9421882 | 4.577067 |
| V16 | 3.6023687 | 3.0291498 | 3.940228 |
| V27 | 3.9114160 | 2.7317811 | 4.606701 |
| V31 | 4.2438797 | 11.4441295 | 0.000000 |
| V45 | 4.6313000 | 0.8321081 | 6.870571 |
| V13 | 6.0729518 | 2.6020604 | 8.118721 |
| V9 | 6.1685352 | 0.6752631 | 9.406310 |
| V10 | 6.2532734 | 0.1409472 | 9.855923 |
| V12 | 7.1804369 | 0.3496059 | 11.206579 |
| V11 | 8.3466767 | 0.0154652 | 13.257153 |
(The columns each sum up to 100.)
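The numbers in the table are consistent with the overall column being a weighted average of the per-model importances, with weights taken from the ensemble coefficients in the summary output above and rescaled to sum to 1. This is our reading of the figures, not something the source states; the sketch assumes the three weights printed by summary are the intercept, glm and rpart, in that order. It checks three rows copied from the table:

```r
# Blend coefficients for glm and rpart from the summary() output
# (assumed order: intercept, glm, rpart; the intercept is dropped)
coefs <- c(glm = -1.0734, rpart = -1.8211)
w <- coefs / sum(coefs)
round(w, 4)

# Three rows copied from the importance table
imp <- data.frame(
  overall = c(2.8817970, 3.6023687, 8.3466767),
  glm     = c(7.7711105, 3.0291498, 0.0154652),
  rpart   = c(0.000000, 3.940228, 13.257153),
  row.names = c("V32", "V16", "V11")
)

# Weighted average of the per-model importances
rebuilt <- imp$glm * w[["glm"]] + imp$rpart * w[["rpart"]]
round(cbind(table = imp$overall, rebuilt = rebuilt), 3)
```

The rebuilt values agree with the table to about three decimal places; the small gap comes from the coefficients being printed to four digits.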
caretStack allows us to move beyond simple blends of models to using
“meta-models” to ensemble collections of predictive models. DO NOT use
the trainControl object you used to fit the training models
to fit the ensemble. The re-sampling indexes will be wrong. Fortunately,
you don't need to be fastidious with re-sampling indexes for caretStack,
as it only fits one model, and the defaults train uses will
usually work fine:
glm_ensemble <- caretStack(
model_list,
method="glm",
metric="ROC",
trControl=trainControl(
method="boot",
number=10,
savePredictions="final",
classProbs=TRUE,
summaryFunction=twoClassSummary
)
)
model_preds2 <- model_preds
model_preds2$ensemble <- predict(glm_ensemble, newdata=testing, type="prob")
CF <- coef(glm_ensemble$ens_model$finalModel)[-1]
colAUC(model_preds2, testing$Class)
##               glm     rpart  ensemble
## M vs. R 0.7137346 0.7746914 0.8140432
CF/sum(CF)
##       glm     rpart
## 0.3708346 0.6291654
Note that glm_ensemble$ens_model is a regular caret
object of class train. The glm-weighted model weights (glm
vs rpart) and test-set AUCs are extremely similar to the caretEnsemble
greedy optimization.
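The idea behind a glm stack can be sketched in base R without caret. The example below uses simulated data throughout: y, p1 and p2 are hypothetical stand-ins for the outcome and for two base models' out-of-sample probabilities, and it is a simplified illustration of the technique rather than what caretStack does internally.

```r
set.seed(42)
n <- 500

# Simulated binary outcome and two noisy base-model probabilities
y  <- rbinom(n, 1, 0.5)
p1 <- plogis(1.0 * (2 * y - 1) + rnorm(n))
p2 <- plogis(1.5 * (2 * y - 1) + rnorm(n))

# The meta-model: a logistic regression on the base-model predictions
stack_fit <- glm(y ~ p1 + p2, family = binomial)

# Relative weights of the two base models (intercept dropped)
cf <- coef(stack_fit)[-1]
cf / sum(cf)

# Blended probabilities from the stack
blend <- predict(stack_fit, type = "response")
range(blend)
```

In a real stack the inputs would be the saved out-of-sample predictions from each model, which is why savePredictions is set in the trainControl objects above.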
We can also use more sophisticated ensembles than simple linear weights, but these models are much more susceptible to over-fitting, and generally require large sets of resamples to train on (n=50 or higher for bootstrap samples). Let's try one anyway:
library("gbm")
gbm_ensemble <- caretStack(
model_list,
method="gbm",
verbose=FALSE,
tuneLength=10,
metric="ROC",
trControl=trainControl(
method="boot",
number=10,
savePredictions="final",
classProbs=TRUE,
summaryFunction=twoClassSummary
)
)
model_preds3 <- model_preds
model_preds3$ensemble <- predict(gbm_ensemble, newdata=testing, type="prob")
colAUC(model_preds3, testing$Class)
##               glm     rpart  ensemble
## M vs. R 0.7137346 0.7746914 0.8009259
In this case, the sophisticated ensemble is no better than a simple weighted linear combination. Non-linear ensembles seem to work best when you have: