
I have an extremely imbalanced dataset (millions of times more negatives than positives) for a binary classification NN model. I am aggressively downsampling the majority class solely to make training time manageable (not to be confused with downsampling in order to bias the model, make accuracy easier to interpret, etc.; those problems I can fix separately by adjusting the classification threshold). In other words, an unbiased sample of my data would require tens of millions of observations just to capture a few positives, which is not practical.

My understanding is that once you do this majority-class downsampling, you are supposed to weight the loss function in order to "calibrate" the probabilities. In other words, I am feeding in roughly balanced data, so the average NN prediction will be in the 0.5 ballpark; however, since the actual positivity rate is 0.0...01, the output probabilities should generally be much lower. This article describes this as upweighting, or calibrating the downsample.

To do this, I am using the `class_weight` argument of the tf.keras (TF 2) `model.fit` step, set to `{False: 1/downsample_rate, True: 1}`, where `1/downsample_rate` is a very big number.
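For reference, here is roughly what that looks like (everything below is a placeholder sketch, not my real pipeline; note that tf.keras documents `class_weight` keys as integer class indices, so I map the boolean labels to 0/1 here):

```python
import tensorflow as tf

# Placeholder value; in reality downsample_rate is the (tiny)
# fraction of negatives kept, so 1/downsample_rate is very large.
downsample_rate = 1e-5

# Toy stand-in for the real architecture.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# Upweight each surviving negative to stand in for the
# 1/downsample_rate negatives it replaced.
model.fit(
    X_train, y_train,  # assumed to exist; labels are 0/1
    epochs=5,
    class_weight={0: 1.0 / downsample_rate, 1: 1.0},
)
```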

Without this argument, I get a very good model with high ROC AUC, except that the predicted probabilities are far too large (the NN thinks the data is balanced). When I add the argument, my ROC performance drops dramatically, and the probabilities are still quite large. My understanding is that ROC AUC isn't affected by weighting classes (since both axes are normalized) or by monotonically rescaling probabilities (it depends only on the rank of the predictions), so it seems the model is actually getting worse.
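To illustrate the invariance I mean, here is a small self-contained check on synthetic scores (not my real model or data):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=1000)
scores = np.clip(rng.normal(0.4 + 0.2 * y, 0.2), 1e-6, 1 - 1e-6)

# Deflating the odds by a constant factor is strictly monotonic,
# so the ranks, and hence the ROC AUC, are unchanged.
beta = 1e-5
shifted = beta * scores / (beta * scores + (1 - scores))

print(roc_auc_score(y, scores))   # some AUC
print(roc_auc_score(y, shifted))  # the same AUC
```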

Any thoughts or suggestions? Why is this happening, am I taking the correct approach, and is there a better way to "calibrate" probabilities in a NN after downsampling?

Note: I tried the same approach on a random forest, and I also (surprisingly) saw a small decrease in AUC from weighting the classes, though the difference was far less dramatic; a rough sketch of that setup follows.
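The random forest version was roughly this (scikit-learn, placeholder values, `X_train`/`y_train` assumed to exist):

```python
from sklearn.ensemble import RandomForestClassifier

downsample_rate = 1e-5  # placeholder: fraction of negatives kept

# Same upweighting idea applied to the random forest baseline.
rf = RandomForestClassifier(
    n_estimators=200,
    class_weight={0: 1.0 / downsample_rate, 1: 1.0},
)
rf.fit(X_train, y_train)  # labels are 0/1
```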

Paul Fornia
  • One idea I had was that a different set of class weights could have a different set of optimal hyperparameters. But I did a quick hyperparameter search with the class weights included; the architecture didn't change much, and performance was still bad. – Paul Fornia Jan 15 '21 at 17:20
  • You could try just shifting the probabilities, e.g. https://stats.stackexchange.com/q/294494/232706 and https://datascience.stackexchange.com/q/58631/55122 (a sketch of such a shift follows these comments). I am surprised to hear that using `class_weight` would hurt AUC substantially, though. Can you share (some version of) the data and code? – Ben Reiniger Jan 16 '21 at 03:14
  • Thanks Ben, I'll try shifting the probabilities. I can't share my code unfortunately, but I'll try to find time to replicate the issue on public data, and will post an update here once I do. – Paul Fornia Jan 17 '21 at 00:50
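A minimal sketch of the probability shift suggested in the links above, assuming `downsample_rate` is the fraction of negatives that survived the downsampling:

```python
import numpy as np

def shift_probabilities(p, downsample_rate):
    """Map probabilities learned on negative-downsampled data back
    to the original class prior. Downsampling inflates the posterior
    odds of the positive class by 1/downsample_rate, so we deflate
    them by the same factor."""
    odds = p / (1.0 - p)
    corrected_odds = downsample_rate * odds
    return corrected_odds / (1.0 + corrected_odds)

# A "balanced-looking" prediction of 0.5 maps back to roughly the
# downsample rate itself.
print(shift_probabilities(np.array([0.5, 0.9]), 1e-5))
```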

1 Answer


With unbalanced data, the ROC curve is not a good representation of your model's quality; the precision-recall (PR) curve is generally better.

With your extremely unbalanced data, you need to think carefully about what your quality goal actually is, then both measure it and optimise for it. I assume that you really dislike false negatives (otherwise the model that always answers "no" would be excellent). But what is your goal for precision? With such an unbalanced data set, any model with decent recall is likely to have low precision.
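As one concrete way to measure this, here is a small sketch using scikit-learn's PR utilities (synthetic labels and scores for illustration only):

```python
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve

# Tiny synthetic stand-in for real labels and model scores.
y_true = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
scores = np.array([.10, .20, .15, .70, .05, .40, .25, .35, .80, .60])

# Average precision summarises the PR curve in one number,
# analogous to ROC AUC for the ROC curve.
print(average_precision_score(y_true, scores))

# Best precision achievable while recalling at least 90% of positives.
precision, recall, _ = precision_recall_curve(y_true, scores)
print(precision[recall >= 0.9].max())
```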

chrishmorris
  • Let's say my goal is to maximize precision while recalling X% of my positives (basically PR AUC). Measured by this metric, or measured by log loss, I'm seeing the same pattern - class weights hurt the model dramatically. – Paul Fornia Jan 15 '21 at 17:11
  • By the way, can you elaborate on why ROC isn't good for unbalanced data? I've heard this a few times, but have never fully understood the argument. I understand that you can have high AUC and a "low" precision (as in my case), but the axes are normalized, so it's not "gameable" like accuracy is. Better AUC should still generally mean better model, no? – Paul Fornia Jan 15 '21 at 17:14
  • Why ROC is not useful for unbalanced data: http://pages.cs.wisc.edu/~jdavis/davisgoadrichcamera2.pdf – chrishmorris Jan 18 '21 at 08:14
  • thanks @chrishmorris! – Paul Fornia Jan 19 '21 at 22:03