
Setup

You have an input dataset X, and each row has multiple labels. For example, with 3 possible labels, a row's target might be [1, 0, 1].

Problem

The typical approach is to use BCEWithLogitsLoss or MultiLabelSoftMarginLoss. But what if the problem is now switched so that either all the labels must be correct, or we don't predict anything at all? (Only predict if certain that all labels are correct.)

What loss function do we pick for this? I thought of coding a custom loss function that returns 0 if all the labels match and 1 otherwise, but that seems “hacky”.

asked by Wboy; edited by kjetil b halvorsen
  • Such a loss function (that can only be 0 or 1) would have a gradient of zero almost everywhere, making it nearly impossible to optimize. – jodag Oct 16 '20 at 22:06
  • What would you recommend? @jodag – Wboy Oct 18 '20 at 07:15
  • You’ve specified the right loss function if what you really want is a model that only predicts a category when it is absolutely certain that the category is correct. I think you don’t want that. Remember that common “classification” techniques like logistic regression and neural networks predict probabilities of class membership that the user can then translate into categories, depending on thresholds and an assessment of the damage caused by misclassifications. – Dave Oct 18 '20 at 07:31
  • Of possible interest: https://stats.stackexchange.com/questions/464636/proper-scoring-rule-when-there-is-a-decision-to-make-e-g-spam-vs-ham-email and the linked answers (and the links in those, like Frank Harrell’s blog). – Dave Oct 18 '20 at 07:33

1 Answer


“problem is now switched to all the labels must be correct, or don't predict anything at all?”

You don't need any special loss function here. You can simply use a regular loss, train the classifier to predict probabilities, and then, at inference time, make a decision only if all the probabilities are higher than some threshold.

– Tim