caret: a machine learning package in R

Introduction

R has a large number of packages for machine learning (ML), which is great, but also quite frustrating, since each package was designed independently and has very different syntax, inputs and outputs. caret (short for Classification And Regression Training) unifies these packages behind a single interface with consistent syntax, saving everyone a lot of frustration and time!

Installing packages

# install.packages(c("tidyverse", "caret"))  # if not already installed
library(tidyverse)  # dplyr, purrr, ggplot2, etc.
library(caret)

Core Functions

  • Data preparation (imputation, centering/scaling data, removing correlated predictors, reducing skewness); see the preProcess() sketch after this list

  • Data splitting

  • Variable selection

  • Model evaluation
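Data preparation is not demonstrated later in this post, so here is a minimal, hypothetical sketch of caret's preProcess(); the names df and df_scaled are illustrative, not objects from the example below, and the input is assumed to be a data frame of numeric predictors.

# Hypothetical sketch: estimate centering, scaling and median imputation
# on the training data, then apply the same transformation with predict()
pre <- preProcess(df, method = c("center", "scale", "medianImpute"))
df_scaled <- predict(pre, df)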

Examples

Load data

library(pdp)  # the pdp package ships the Pima Indians diabetes data (pima)
data(pima)
head(pima)
##   pregnant glucose pressure triceps insulin mass pedigree age diabetes
## 1        6     148       72      35      NA 33.6    0.627  50      pos
## 2        1      85       66      29      NA 26.6    0.351  31      neg
## 3        8     183       64      NA      NA 23.3    0.672  32      pos
## 4        1      89       66      23      94 28.1    0.167  21      neg
## 5        0     137       40      35     168 43.1    2.288  33      pos
## 6        5     116       74      NA      NA 25.6    0.201  30      neg

Data splitting

Before creating trainData and testData, we need to remove rows with NA values and recode glucose into a two-level factor (High if above 143, Low otherwise). We then use createDataPartition() with p = 0.8 to split the data into 80% trainData and 20% testData.

# Drop rows with missing values and dichotomize glucose
df <- pima %>%
  na.omit() %>%
  as_tibble() %>%
  mutate(glucose = factor(ifelse(glucose > 143, "High", "Low")))

set.seed(123)  # for a reproducible split
samples <- createDataPartition(df$diabetes, p = 0.8, list = FALSE)
trainData <- df[samples, ]
testData  <- df[-samples, ]
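createDataPartition() samples within each level of the outcome, so the class balance of diabetes should be roughly the same in both splits. A quick check, using the objects created above:

# Class proportions should be similar in the training and test sets
prop.table(table(trainData$diabetes))
prop.table(table(testData$diabetes))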

Training the model

  • Resampling options (trainControl): specify how resampling (e.g. cross-validation) is performed; tuning parameters well is one of the most important parts of training ML models

  • Model parameter tuning options (tuneGrid): specify the tuning grid for the model's parameters

  • In the code below, trainControl sets up the resampling and expand.grid builds the tuning grid for the model

require(MLmetrics)  # prSummary needs MLmetrics

set.seed(123)

# 5-fold cross-validation with class probabilities; prSummary reports
# AUC, Precision, Recall and F for each candidate model
myControl <- trainControl(method = "cv",
                          number = 5,
                          classProbs = TRUE,
                          summaryFunction = prSummary,
                          verboseIter = FALSE)

# Tuning grid over gbm's four parameters
gbm.grid <- expand.grid(interaction.depth = c(1, 2, 8),
                        n.trees = c(50, 100, 200, 250, 300),
                        shrinkage = 0.1,
                        n.minobsinnode = 20)

model_gbm <- train(diabetes ~ .,
                   data = trainData,
                   method = "gbm",
                   trControl = myControl,
                   verbose = FALSE,
                   tuneGrid = gbm.grid,
                   metric = "AUC")
model_gbm
## Stochastic Gradient Boosting 
## 
## 314 samples
##   8 predictor
##   2 classes: 'neg', 'pos'
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 252, 251, 251, 251, 251
## Resampling results across tuning parameters:
## 
##   interaction.depth  n.trees  AUC        Precision  Recall     F        
##   1                   50      0.8650145  0.7996909  0.8666667  0.8307793
##   1                  100      0.8736543  0.7954578  0.8428571  0.8175589
##   1                  200      0.8744035  0.8069882  0.8333333  0.8184761
##   1                  250      0.8680427  0.7903613  0.8428571  0.8150945
##   1                  300      0.8662900  0.7858333  0.8333333  0.8081787
##   2                   50      0.8765260  0.7781000  0.8333333  0.8046089
##   2                  100      0.8667558  0.7917460  0.8142857  0.8026273
##   2                  200      0.8693913  0.7804105  0.7952381  0.7876505
##   2                  250      0.8646724  0.7916702  0.8095238  0.8002175
##   2                  300      0.8656153  0.7831539  0.8238095  0.8026511
##   8                   50      0.8731605  0.7887164  0.8285714  0.8074166
##   8                  100      0.8735185  0.7773417  0.8142857  0.7948897
##   8                  200      0.8687409  0.7986995  0.8238095  0.8103645
##   8                  250      0.8666677  0.7897287  0.8238095  0.8051805
##   8                  300      0.8600171  0.7956605  0.8142857  0.8033272
## 
## Tuning parameter 'shrinkage' was held constant at a value of 0.1
## Tuning parameter 'n.minobsinnode' was held constant at a value of 20
## AUC was used to select the optimal model using the largest value.
## The final values used for the model were n.trees = 50, interaction.depth = 2, shrinkage = 0.1 and n.minobsinnode = 20.

Based on the largest AUC, the final values used for the model were n.trees = 50, interaction.depth = 2, shrinkage = 0.1 and n.minobsinnode = 20.
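To see how performance varies across the tuning grid, you can plot the train object directly (plot() on a train object is standard caret and shows the resampled metric against the tuning parameters):

# Visualize resampled AUC across n.trees and interaction.depth
plot(model_gbm)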

Model prediction

pred <- predict(model_gbm, newdata = testData)  # predicted class labels
confusionMatrix(pred, testData$diabetes)        # compare against the truth
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction neg pos
##        neg  42   9
##        pos  10  17
## 
##                Accuracy : 0.7564
##                  95% CI : (0.646, 0.8465)
##     No Information Rate : 0.6667
##     P-Value [Acc > NIR] : 0.05651
## 
##                   Kappa : 0.4571
## 
##  Mcnemar's Test P-Value : 1.00000
## 
##             Sensitivity : 0.8077
##             Specificity : 0.6538
##          Pos Pred Value : 0.8235
##          Neg Pred Value : 0.6296
##              Prevalence : 0.6667
##          Detection Rate : 0.5385
##    Detection Prevalence : 0.6538
##       Balanced Accuracy : 0.7308
## 
##        'Positive' Class : neg
## 

The accuracy of model_gbm on the test set is 0.7564.
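Because the model was trained with classProbs = TRUE, predict() can also return class probabilities instead of hard labels (a short sketch using the fitted model; prob is an illustrative name):

# Predicted probability of each class for the test set
prob <- predict(model_gbm, newdata = testData, type = "prob")
head(prob)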

Important features

require(gbm)             # needed so varImp() can query the gbm fit
plot(varImp(model_gbm))  # plot variable importance
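Printing varImp(model_gbm) shows the numbers behind the plot; by default the importances are scaled so that the largest is 100:

# Relative variable importance (scaled to 0-100 by default)
varImp(model_gbm)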

Running other machine learning methods for comparison

  • gbm

  • random forest

  • SVM

  • RDA

set.seed(123)

# All four models share the same resampling setup (myControl);
# without a tuneGrid, caret builds a default grid for each method
model_gbm <- train(diabetes ~ .,
                   data = trainData,
                   method = "gbm",
                   trControl = myControl,
                   verbose = FALSE,
                   metric = "AUC")

model_rf <- train(diabetes ~ .,
                  data = trainData,
                  method = "rf",
                  trControl = myControl,
                  verbose = FALSE,
                  metric = "AUC")

model_svm <- train(diabetes ~ .,
                   data = trainData,
                   method = "svmRadial",
                   trControl = myControl,
                   verbose = FALSE,
                   metric = "AUC")

model_rda <- train(diabetes ~ .,
                   data = trainData,
                   method = "rda",
                   trControl = myControl,
                   verbose = FALSE,
                   metric = "AUC")
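When no tuneGrid is supplied, the size of caret's default grid can be controlled with the tuneLength argument (a sketch; model_rf2 is a hypothetical name, not used elsewhere in this post):

# tuneLength = 5 asks caret to try 5 values per tuning parameter
model_rf2 <- train(diabetes ~ .,
                   data = trainData,
                   method = "rf",
                   trControl = myControl,
                   tuneLength = 5,
                   metric = "AUC")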

Summary of all the model fits

# Collect the resampling results of all four models for comparison
all.fit <- resamples(list(GBM = model_gbm, RF = model_rf, SVM = model_svm, RDA = model_rda))
summary(all.fit)
## 
## Call:
## summary.resamples(object = all.fit)
## 
## Models: GBM, RF, SVM, RDA
## Number of resamples: 5
## 
## AUC
##          Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## GBM 0.8380961 0.8446846 0.8734812 0.8730419 0.8897267 0.9192206    0
## RF  0.8634492 0.8698507 0.8727046 0.8749921 0.8756874 0.8932685    0
## SVM 0.8364399 0.8404596 0.8614867 0.8643218 0.8827897 0.9004330    0
## RDA 0.8304573 0.8697133 0.8726010 0.8789965 0.8994164 0.9227945    0
## 
## F
##          Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## GBM 0.8000000 0.8000000 0.8351648 0.8238231 0.8395062 0.8444444    0
## RF  0.7857143 0.8139535 0.8372093 0.8313310 0.8433735 0.8764045    0
## SVM 0.8089888 0.8131868 0.8372093 0.8290307 0.8387097 0.8470588    0
## RDA 0.8045977 0.8048780 0.8089888 0.8198274 0.8235294 0.8571429    0
## 
## Precision
##          Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## GBM 0.7500000 0.7755102 0.7906977 0.7959339 0.7916667 0.8717949    0
## RF  0.7857143 0.7954545 0.8181818 0.8165593 0.8297872 0.8536585    0
## SVM 0.7551020 0.7647059 0.7659574 0.7882313 0.8181818 0.8372093    0
## RDA 0.7659574 0.7777778 0.8139535 0.8079663 0.8250000 0.8571429    0
## 
## Recall
##          Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## GBM 0.8095238 0.8095238 0.8571429 0.8571429 0.9047619 0.9047619    0
## RF  0.7857143 0.8333333 0.8333333 0.8476190 0.8571429 0.9285714    0
## SVM 0.8571429 0.8571429 0.8571429 0.8761905 0.8809524 0.9285714    0
## RDA 0.7857143 0.8333333 0.8333333 0.8333333 0.8571429 0.8571429    0
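The resamples object also supports lattice plots for a quick visual comparison (bwplot() on a resamples object is part of caret):

# Box-and-whisker plot of the resampled AUC values per model
bwplot(all.fit, metric = "AUC")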

ROC

library(pROC)

# Compute a model's ROC curve on a data set, using the predicted
# probability of the 'pos' class
test_roc <- function(model, data){
  roc(data$diabetes, predict(model, data, type = "prob")[, "pos"])
}

all_model <- list(GBM = model_gbm, RF = model_rf, SVM = model_svm, RDA = model_rda)
all_model_roc <- all_model %>% map(test_roc, data = testData)
all_model_roc %>% map(auc)  # test-set AUC for each model
## $GBM
## Area under the curve: 0.8077
##
## $RF
## Area under the curve: 0.7837
##
## $SVM
## Area under the curve: 0.7722
##
## $RDA
## Area under the curve: 0.7885
# Build a long data frame of (fpr, tpr) points for every model
df_roc <- c()
for(i in 1:length(all_model)){
  a <- test_roc(all_model[[i]], testData)
  b <- tibble(tpr = a$sensitivities,
              fpr = 1 - a$specificities,
              model = names(all_model)[i])
  df_roc <- rbind(df_roc, b)
}

ggplot(data = df_roc, aes(x = fpr, y = tpr, group = model)) +
  geom_line(aes(color = model), size = 1) +
  geom_abline(intercept = 0, slope = 1, color = "grey", size = 1) +
  theme_bw() +
  labs(title = "ROC Curves for all models",
       x = "False Positive Rate (1 - Specificity)",
       y = "True Positive Rate (Sensitivity or Recall)")

R information

sessionInfo()
## R version 3.6.1 (2019-07-05)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 18363)
##
## Matrix products: default
##
## locale:
## [1] LC_COLLATE=Chinese (Simplified)_China.936 LC_CTYPE=Chinese (Simplified)_China.936
## [3] LC_MONETARY=Chinese (Simplified)_China.936 LC_NUMERIC=C
## [5] LC_TIME=Chinese (Simplified)_China.936
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] pROC_1.16.2 gbm_2.1.8 MLmetrics_1.1.1 pdp_0.7.0 caret_6.0-86 lattice_0.20-38 forcats_0.5.0
## [8] stringr_1.4.0 dplyr_1.0.2 purrr_0.3.3 readr_1.3.1 tidyr_1.0.0 tibble_3.0.4 ggplot2_3.3.2
## [15] tidyverse_1.3.0
##
## loaded via a namespace (and not attached):
## [1] colorspace_1.4-1 ellipsis_0.3.1 class_7.3-15 base64enc_0.1-3 fs_1.5.0
## [6] rstudioapi_0.10 farver_2.0.3 rstan_2.19.3 prodlim_2019.11.13 fansi_0.4.1
## [11] lubridate_1.7.9 xml2_1.2.2 codetools_0.2-16 splines_3.6.1 knitr_1.30
## [16] jsonlite_1.7.1 broom_0.7.2 kernlab_0.9-29 dbplyr_1.4.4 shiny_1.4.0
## [21] compiler_3.6.1 httr_1.4.2 backports_1.1.10 fastmap_1.0.1 assertthat_0.2.1
## [26] Matrix_1.2-18 cli_2.1.0 later_1.1.0.1 htmltools_0.5.0 prettyunits_1.1.1
## [31] tools_3.6.1 gtable_0.3.0 glue_1.4.2 reshape2_1.4.3 Rcpp_1.0.3
## [36] cellranger_1.1.0 vctrs_0.3.4 nlme_3.1-143 iterators_1.0.13 timeDate_3043.102
## [41] gower_0.2.2 xfun_0.19 ps_1.3.0 rvest_0.3.5 mime_0.9
## [46] miniUI_0.1.1.1 lifecycle_0.2.0 gtools_3.8.2 MASS_7.3-51.5 scales_1.1.0
## [51] ipred_0.9-9 promises_1.1.1 hms_0.5.3 parallel_3.6.1 inline_0.3.16
## [56] yaml_2.2.0 gridExtra_2.3 loo_2.3.1 StanHeaders_2.21.0-1 labelled_2.7.0
## [61] rpart_4.1-15 stringi_1.4.3 highr_0.8 klaR_0.6-15 foreach_1.5.1
## [66] randomForest_4.6-14 e1071_1.7-4 caTools_1.18.0 pkgbuild_1.1.0 lava_1.6.8
## [71] rlang_0.4.8 pkgconfig_2.0.3 matrixStats_0.57.0 bitops_1.0-6 evaluate_0.14
## [76] ROCR_1.0-7 labeling_0.4.2 recipes_0.1.14 processx_3.4.4 tidyselect_1.1.0
## [81] plyr_1.8.6 magrittr_1.5 R6_2.4.1 gplots_3.1.0 generics_0.0.2
## [86] combinat_0.0-8 DBI_1.1.0 pillar_1.4.6 haven_2.3.1 withr_2.1.2
## [91] survival_3.2-7 nnet_7.3-12 modelr_0.1.8 crayon_1.3.4 questionr_0.7.3
## [96] KernSmooth_2.23-16 rmarkdown_2.0 grid_3.6.1 readxl_1.3.1 data.table_1.13.2
## [101] blob_1.2.1 callr_3.5.1 ModelMetrics_1.2.2.2 reprex_0.3.0 digest_0.6.27
## [106] xtable_1.8-4 httpuv_1.5.4 stats4_3.6.1 munsell_0.5.0

Reference

  1. Machine learning caret package

  2. caret package tutorial

  3. R machine learning workflow and case study (R 机器学习流程及案例实现)

If any of the references raise infringement concerns, please contact me immediately via email.


------------- The End. Thanks for reading --------