
Why is this toy example so difficult for a neural net to learn? My guess is that the output of the first hidden layer is not normalized, so the propagated gradient is not very stable. I've tried adding BatchNormalization between the two linear layers, but it has no visible effect on the optimization.

UPD: It seems like this particular behavior is caused by the unscaled target variable. This matches my experiments, which showed that learning occurs only with a very small learning rate.

Example code:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Quadratic target on an integer grid
x = np.arange(100)[:, np.newaxis]
f = lambda x: x ** 2 + 3
y = f(x)

# The input is standardized, but the target y (which reaches ~10^4) is not
x_normed = (x - x.mean()) / x.std()

# Two stacked linear layers, i.e. a linear model
km = Sequential()
km.add(Dense(1, activation=None))
km.add(Dense(1, activation=None))
km.compile('sgd', loss='mse')

# With the default SGD learning rate the loss quickly diverges to nan/inf
km.fit(x_normed, y, epochs=300)
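
Following the UPD above, here is a minimal sketch of the same setup with the target standardized the same way as the input (only y_normed is new); with both sides scaled, the default SGD settings no longer blow up and the loss settles at the error of the best linear fit:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

x = np.arange(100)[:, np.newaxis]
y = x ** 2 + 3

# Standardize the target as well as the input; raw y reaches ~10^4,
# which is what makes the MSE and its gradients huge
x_normed = (x - x.mean()) / x.std()
y_normed = (y - y.mean()) / y.std()

km = Sequential()
km.add(Dense(1, activation=None))
km.add(Dense(1, activation=None))
km.compile('sgd', loss='mse')

# The loss now decreases smoothly instead of diverging to nan/inf
km.fit(x_normed, y_normed, epochs=300)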
ptyshevs
  • you're fitting a quadratic with a linear model?? – shimao Jul 26 '19 at 01:41
  • What are you trying to achieve with this experiment? What do you mean by "difficult to learn"? A linear model will never be able to fit a quadratic function exactly... – Jan Kukacka Jul 26 '19 at 08:58
  • @shimao, yes, the point is to show the clear discrepancy between the model that generated the data and the model that approximates the generating function from samples. – ptyshevs Jul 26 '19 at 09:52
  • @JanKukacka I'm trying to achieve convergence of the optimization process for the given simple model. I understand that an exact fit is not possible, but that's not the point. In the example given, optimization blows up and the loss goes to `nan` or `inf`. – ptyshevs Jul 26 '19 at 09:55
  • Try using lower learning rate then. Also, check https://stats.stackexchange.com/questions/352036/what-should-i-do-when-my-neural-network-doesnt-learn for general tips on debugging the learning process – Jan Kukacka Jul 26 '19 at 09:57
  • @JanKukacka thanks, that helped, but I am looking for a more theoretical explanation of this behavior. Should I delete this question? – ptyshevs Jul 27 '19 at 12:25
  • 2
    i think the main problem why was hard to fit is you didn't normalize $y$, resulting in squared errors on the other of 10^8 – shimao Jul 28 '19 at 19:31
  • @shimao wow, I hadn't thought about that. I know situations where you need to transform the target variable (for example, a log-transform of a heavy-tailed target distribution), but I haven't seen normalization of the target before. Is it that common? – ptyshevs Jul 30 '19 at 07:41
  • yes it is standard – shimao Jul 30 '19 at 08:22
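
To make the last comment concrete, here is a rough back-of-the-envelope sketch (pure NumPy, a single linear unit with an illustrative initial weight, and Keras's default SGD learning rate of 0.01) of the scale of the very first update when the target is left unscaled:

import numpy as np

# Same data as in the question: input standardized, target left on its raw scale (~10^4)
x = np.arange(100)[:, np.newaxis]
y = x ** 2 + 3
x_normed = (x - x.mean()) / x.std()

# One linear unit y_hat = w * x + b with an illustrative small initial weight
w, b = 0.05, 0.0
y_hat = w * x_normed + b
err = y_hat - y                                   # residuals on the order of 1e3-1e4

print("initial MSE:", float(np.mean(err ** 2)))   # roughly 1e7-1e8
grad_w = 2 * np.mean(err * x_normed)              # dMSE/dw
print("first SGD update to w:", -0.01 * grad_w)   # tens of units with the default lr=0.01

An update of tens of units to a weight that should end up of order one explains why the optimization diverges to nan/inf unless the learning rate is made very small or the target is scaled.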
