I have been trying to find an answer to this question for some time. I understand that cross-validation is primarily used for model selection, i.e. to tune parameters/hyperparameters, but I don’t quite understand why you wouldn’t use the same procedure to build a final model.
I have looked at the following questions/answers, and while they have helped me understand the issue, they simply repeat the adage that "you shouldn't use cross-validation to build a final model" without explaining WHY.
- How to choose a predictive model after k-fold cross-validation?
- Feature selection for "final" model when performing cross-validation in machine learning
- Training with the full dataset after cross-validation?
- Nested cross validation for model selection
- How to build the final model and tune probability threshold after nested cross-validation?
In particular, why not use cross-validation to consecutively update model weights and prevent overfitting? Consider the following scenario:
Task: Text classification
Training/validation dataset: 60,000 sentences labeled with one or more of three labels, A, B, C.
Note: This is therefore a multilabel classification task in which all the labels are imbalanced - there are roughly 6,000 instances of each label in the whole dataset.
Test dataset: Separate dataset held out for testing, roughly 1,500 sentences.
Note: The datasets have similar distributions of label imbalance, word type and sentence length.
Procedure: In training/hyperparameter tuning, use k-fold cross-validation so that the network sees all the data in the training set, with validation loss/accuracy as the measure of model performance. This yields K models whose validation scores indicate how well a model with a particular set of hyperparameters is likely to perform when trained on the whole dataset. Each of the K models can then be tested on the held-out dataset to assess its performance on real-world data (i.e. nested cross-validation).
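For concreteness, here is a minimal sketch of the loop I have in mind (the seeded random arrays and the tiny dense network are only placeholders for my vectorized sentences, the binary label matrix, and the real classifier; `build_model` and the other names are just my own stand-ins):

```python
import numpy as np
import tensorflow as tf
from sklearn.model_selection import KFold

rng = np.random.default_rng(42)

# Stand-ins for my vectorized sentences and the binary label matrix (A, B, C).
X = rng.random((600, 300)).astype("float32")        # 600 "sentences", 300 features each
Y = (rng.random((600, 3)) > 0.9).astype("float32")  # sparse multilabel targets
X_test = rng.random((150, 300)).astype("float32")
Y_test = (rng.random((150, 3)) > 0.9).astype("float32")

def build_model():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(300,)),
        tf.keras.layers.Dense(64, activation="relu"),
        tf.keras.layers.Dense(3, activation="sigmoid"),  # one sigmoid per label
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["binary_accuracy"])
    return model

kf = KFold(n_splits=5, shuffle=True, random_state=42)
val_losses, test_losses = [], []
for train_idx, val_idx in kf.split(X):
    model = build_model()  # fresh weights for every fold
    model.fit(X[train_idx], Y[train_idx],
              validation_data=(X[val_idx], Y[val_idx]),
              epochs=5, batch_size=32, verbose=0)
    val_losses.append(model.evaluate(X[val_idx], Y[val_idx], verbose=0)[0])
    test_losses.append(model.evaluate(X_test, Y_test, verbose=0)[0])

print("mean CV loss:", np.mean(val_losses), "| mean held-out loss:", np.mean(test_losses))
```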
At this point three things seem reasonable:
1. Use the above procedure to identify the best hyperparameters and train a new model on all the data, without cross-validation.
2. Re-run the best model using cross-validation.
3. When running the above procedure, rather than taking the mean of the cross-validation results as the indication of model performance, let each CV fold keep updating the same model's weights, save the best model that results, and evaluate that model on the held-out dataset. [edit: this is possible in Python > Keras, for example - see the sketch after this list]
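Here is roughly what I mean by option 3, continuing from the sketch above (it reuses `build_model`, `X`, `Y`, `X_test` and `Y_test` from there). My understanding is that reusing the same ModelCheckpoint instance across fit() calls keeps tracking the best validation loss seen so far, so the weights that survive are the best seen across all folds:

```python
import tensorflow as tf
from sklearn.model_selection import KFold

# Keep whichever weights give the lowest validation loss seen so far;
# the same callback instance is reused for every fold.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "best.weights.h5", monitor="val_loss",
    save_best_only=True, save_weights_only=True)

model = build_model()  # built once, NOT rebuilt per fold, so weights carry over
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, val_idx in kf.split(X):
    model.fit(X[train_idx], Y[train_idx],
              validation_data=(X[val_idx], Y[val_idx]),
              epochs=5, batch_size=32, verbose=0,
              callbacks=[checkpoint])

model.load_weights("best.weights.h5")  # restore the best checkpoint
print("held-out loss:", model.evaluate(X_test, Y_test, verbose=0)[0])
```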
My intuition is that option 2 or 3 will give a model trained on the whole dataset that is less likely to be overfitted than a model trained without cross-validation, and thus more generalizable to real-world data. I do understand that overfitting should be combated by other measures (which I use), but in my experience models built using cross-validation are less overfitted than those built without it. Also, by checkpointing (option 3) I avoid having to build a new model with a given set of parameters, saving me an extra step and a large amount of computational resources.
One drawback of option 3 that I can see is that it gives me a near-upper bound on model performance, while strict cross-validation (taking the mean) gives a lower bound. So this is a trade-off, and it means that following procedure 1 can give better model training on the full dataset. But better training on the full dataset does not necessarily lead to better generalization (just compare the public vs. private leaderboards in Kaggle competitions), and procedure 3 seems to achieve better generalization.
The only other thing I can think of is that, since you don't control which sentences/labels end up in each fold, updating weights fold by fold may push the model toward erroneous weights depending on how the data in that particular fold is structured. However, I would think this can be mitigated by using stratified k-folds and/or repeated stratified k-folds. Again, performance is ultimately measured by prediction on the held-out dataset.
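For example, since scikit-learn's stratified splitters expect a single target per example, a rough workaround I have in mind for the multilabel case is to stratify on the label combination (the random label matrix below is again just a stand-in):

```python
import numpy as np
from sklearn.model_selection import RepeatedStratifiedKFold

rng = np.random.default_rng(0)
Y = (rng.random((6000, 3)) > 0.8).astype(int)            # stand-in label matrix (A, B, C)
combo = np.array(["".join(map(str, row)) for row in Y])  # e.g. "101" = labels A and C

rskf = RepeatedStratifiedKFold(n_splits=5, n_repeats=3, random_state=42)
for train_idx, val_idx in rskf.split(Y, combo):
    # each fold now has roughly the same mix of label combinations
    pass
```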
Is there something I’m missing here? Are there papers that compare models built during cross-validation vs models built after cross-validation?