I was wondering if someone could provide some insight on the pros/cons of using built-in cross validation functions like cv.glmnet (https://www.rdocumentation.org/packages/glmnet/versions/3.0-2/topics/cv.glmnet), as opposed to constructing a manually coded cross validation scheme with caret.

In particular, I want to do a stratified cross-validation so that each fold has roughly the same proportions of the response variable. I'm not sure whether this is possible with the built-in CV functions, so any insight would be much appreciated.

StupidWolf

1 Answer

I think these are the main differences:

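1. The loss function. The metric cv.glmnet uses to pick the best parameter is restricted to the values of type.measure. For a binomial model these are the model deviance ("deviance"), misclassification error, i.e. 1 - accuracy ("class"), AUC ("auc"), and squared or absolute error ("mse", "mae"). With caret you can use those and also others such as Cohen's kappa or precision.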
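As a rough sketch of the caret side, the snippet below selects the tuning parameters on Cohen's kappa instead of accuracy. The two-level recoding of iris and the 5-fold setting are assumptions made for illustration; caret's default tuning grid is used.

```r
library(caret)

set.seed(1)
dat <- iris
# Two-class response (illustrative recoding, not from the question)
dat$Species <- factor(ifelse(dat$Species == "versicolor", "versicolor", "other"))

ctrl <- trainControl(method = "cv", number = 5)

# For classification the default summary reports Accuracy and Kappa;
# metric = "Kappa" makes train() pick the tuning parameters on Kappa
fit_kappa <- train(x = dat[, 1:4], y = dat$Species,
                   method = "glmnet", metric = "Kappa",
                   trControl = ctrl)
fit_kappa$bestTune
```

Other metrics, such as precision, would need a different summaryFunction in trainControl.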

2. Stratified cross-validation is not a real problem. You can generate the folds with caret's createFolds, which samples within the levels of a factor, and pass them to cv.glmnet through foldid:

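library(caret)
library(glmnet)
data  = iris
# binary response: 1 for versicolor, 0 otherwise
data$Species=as.numeric(data$Species=="versicolor")
# createFolds samples within each class, so every fold keeps the class proportions
dataFolds = createFolds(factor(data$Species),5)
# one fold number per held-out row, in the order of unlist(dataFolds)
fold_id = rep(1:length(dataFolds),sapply(dataFolds,length))

# family="binomial" is needed for type.measure="class"; order() puts the fold numbers back in row order
mdl1 = cv.glmnet(x=as.matrix(data[,1:4]),y=data[,5],alpha=1,family="binomial",
foldid = fold_id[order(unlist(dataFolds))],type.measure="class")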
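A shorter way to get the same kind of fold vector, shown here as a sketch, is to ask createFolds for a vector instead of a list (list = FALSE). The table at the end is only a sanity check that each fold holds a similar class mix; the seed and the variable names are illustrative.

```r
library(caret)

set.seed(1)
y <- as.numeric(iris$Species == "versicolor")

# list = FALSE returns one fold number per observation, in row order,
# which is the format cv.glmnet expects for foldid
fold_id <- createFolds(factor(y), k = 5, list = FALSE)

# Rows are folds, columns are classes; the counts per class should be
# close to equal across folds
fold_table <- table(fold = fold_id, class = y)
fold_table
```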

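3. The choice of lambda. cv.glmnet reports both lambda.min (the lambda with the lowest cross-validated error) and lambda.1se (the largest lambda whose error is within one standard error of that minimum), and coef() and predict() on the fitted object use lambda.1se unless you say otherwise. See this post.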
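A small sketch of how the two lambdas can be inspected and how to override the default, again using a binary versicolor response on iris as an assumed example:

```r
library(glmnet)

set.seed(1)
x <- as.matrix(iris[, 1:4])
y <- as.numeric(iris$Species == "versicolor")

cvfit <- cv.glmnet(x, y, family = "binomial",
                   type.measure = "class", nfolds = 5)

# Both candidate lambdas are stored on the fitted object
c(lambda.min = cvfit$lambda.min, lambda.1se = cvfit$lambda.1se)

coef_default <- coef(cvfit)                    # uses s = "lambda.1se"
coef_min     <- coef(cvfit, s = "lambda.min")  # lambda with the lowest CV error
```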

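4. Tuning alpha. cv.glmnet only cross-validates lambda for one fixed alpha, so to tune alpha you have to run cv.glmnet once per candidate value, ideally with the same foldid each time so the runs are comparable. caret tunes alpha and lambda together through tuneGrid.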
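One way this could look, as a sketch rather than a recommended recipe: loop over a small, hypothetical grid of alpha values, reuse one stratified foldid vector, and compare the best cross-validated error of each run.

```r
library(glmnet)
library(caret)

set.seed(1)
x <- as.matrix(iris[, 1:4])
y <- as.numeric(iris$Species == "versicolor")

# The same stratified folds are reused for every alpha
fold_id <- createFolds(factor(y), k = 5, list = FALSE)

alphas <- c(0, 0.5, 1)  # hypothetical grid, chosen only for illustration
cv_fits <- lapply(alphas, function(a) {
  cv.glmnet(x, y, family = "binomial", alpha = a,
            foldid = fold_id, type.measure = "class")
})

# Lowest cross-validated misclassification error reached for each alpha
cv_err <- sapply(cv_fits, function(f) min(f$cvm))
data.frame(alpha = alphas, cv_error = cv_err)

best_alpha <- alphas[which.min(cv_err)]
```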

5. Speed. cv.glmnet is usually faster than caret on a large dataset because it does not store as much information as caret does. For example:

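library(microbenchmark)
fit_cv = function(){
cv.glmnet(x=as.matrix(data[,1:4]),y=data[,5],alpha=1,family="binomial",
foldid = fold_id[order(unlist(dataFolds))],type.measure="class")
}
# G was not defined in the original code; this lasso grid is an assumption
G = expand.grid(alpha=1,lambda=10^seq(-4,0,length.out=50))
# createFolds returns held-out rows, but trainControl(index=) expects training rows
trainFolds = lapply(dataFolds,function(i) setdiff(seq_len(nrow(data)),i))
fit_caret = function(){
train(x=data[,1:4],y=factor(data[,5]),method="glmnet",family="binomial",
tuneGrid=G,trControl=trainControl(method="cv",index=trainFolds))
}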

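microbenchmark(fit_cv(),fit_caret(),times=10)

Note the parentheses. microbenchmark(fit_cv,fit_caret,times=10) only times the lookup of the two function objects without running them, which is what the nanosecond-scale output below measured, so these numbers do not show a real difference between the two approaches:

Unit: nanoseconds
      expr min  lq  mean median  uq  max neval cld
    fit_cv 131 173 379.3  324.0 581  877    10   a
 fit_caret 132 263 550.1  440.5 587 1342    10   a

With the functions actually called, I would expect the gap to grow as your dataset gets larger.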

StupidWolf