
I am training a feed-forward network on a regression problem (MSE loss) to predict a scalar value (1x1) from an input of size Nx1. My batch size is 400, though the problem persists across a range of batch sizes.

My problem is that the optimizer always converges to a solution where it predicts the same value for every element of the batch. For example, my real labels (for a batch of 400) might be:

[34, 50, 12, ... , 23, 45, 22] (size = 1x400)

And what it predicts will be something like this:

[29, 29, 29, .... , 29, 29, 29] (size = 1x400)

I think I understand why it is doing this, and it has to do with using MSE. The MSE is averaged over the whole batch, and that average is the signal the optimizer receives to update the weights. That signal does not convey that predicting the same value for every element of the batch is a sub-optimal solution.
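A minimal plain-Python sketch of the batch-averaged MSE for a constant prediction. The labels are made up for illustration (not the real data), and `batch_mse` is a hypothetical helper, not the TensorFlow loss used in training. It only shows that, among constant predictions, the batch mean of the labels gives the lowest MSE, and that this MSE equals the variance of the labels.

```python
from statistics import fmean, pvariance


def batch_mse(preds, labels):
    """Mean squared error averaged over one batch (hypothetical helper)."""
    return fmean((p - t) ** 2 for p, t in zip(preds, labels))


# Illustrative labels only (made up, not the real data).
labels = [34.0, 50.0, 12.0, 23.0, 45.0, 22.0]

# Predict the batch mean for every element of the batch.
mean_label = fmean(labels)
constant_preds = [mean_label] * len(labels)

# The MSE of this constant prediction equals the population variance of the labels.
mse_at_mean = batch_mse(constant_preds, labels)
label_variance = pvariance(labels)

# Any other constant (here 29, as in the example output above) scores worse.
mse_at_29 = batch_mse([29.0] * len(labels), labels)

print(mse_at_mean, label_variance, mse_at_29)
```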

I have tried a couple of solutions to this, but all of them run into the same problem. For example, I have tried adding a penalty when the variance of the predictions does not match the variance of the labels, i.e. matching second moments. Instead of

$$ loss = (\hat y - y)^2 $$

I use

$$ loss = (\hat y - y)^2 + (var(\hat y) - var(y))^2 $$

And it still does the same thing.
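The variance-matching loss described above can be sketched in plain Python as follows. This is an illustration under assumptions the question does not state: both terms are computed over a single batch, population variance is used, and the two terms are weighted equally. The function name `variance_matched_loss` and the labels are hypothetical.

```python
from statistics import fmean, pvariance


def variance_matched_loss(preds, labels):
    """MSE plus the squared gap between prediction and label variances.

    Assumptions (not from the question): per-batch computation,
    population variance, and equal weighting of the two terms.
    """
    mse = fmean((p - t) ** 2 for p, t in zip(preds, labels))
    var_gap = (pvariance(preds) - pvariance(labels)) ** 2
    return mse + var_gap


# Illustrative labels only (made up, not the real data).
labels = [34.0, 50.0, 12.0, 23.0, 45.0, 22.0]

# A collapsed prediction: the same value for every element of the batch.
constant_preds = [29.0] * len(labels)

loss_constant = variance_matched_loss(constant_preds, labels)
loss_perfect = variance_matched_loss(labels, labels)

print(loss_constant, loss_perfect)
```

In this sketch a constant prediction has zero variance, so the penalty term is strictly positive whenever the labels vary within the batch.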

From reading some other questions, it seems that this is to be expected with MSE. I am wondering whether anyone has tried something different, such as a variation of MSE or a different batching scheme, that solves this problem. I realize this is a difficult problem, but given the limitations of softmax classification on a regression-type problem, I think it is a very worthwhile one to solve.

Has anyone else encountered this or found a solution for it?

(I am using TensorFlow and a setup with 4 GPUs, if it makes a difference.)

sfortney
  • You said your network is predicting a scalar 1x1 but you also mentioned that the output size is 1x400. Which is it? The output should definitely not be a scalar. – shimao Sep 28 '17 at 02:38
  • 400 is the batch size. It is predicting a 1x1 scalar batch_size (400) times, hence a 1x400 output – sfortney Sep 28 '17 at 15:17
